
Conversation

@Liangliang-Ma (Collaborator) commented Aug 22, 2025:

Developing grouped_gemm_bf16 for the Llama4-Scout fused MoE.
Based on the CUTLASS environment from #11.
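
For context, a minimal sketch of what a fused-MoE forward computes, assuming the usual Llama-style gated MLP (SiLU) with the gate and up projections fused into w13; all names and shapes below are illustrative, not this PR's actual interface:

import torch
import torch.nn.functional as F

def fused_moe_ref(hidden_states, w13, w2, topk_weights, topk_ids):
    # hidden_states: (T, K); w13: (E, 2N, K); w2: (E, K, N)
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for weight, e in zip(topk_weights[t], topk_ids[t]):
            gate_up = hidden_states[t] @ w13[e].T   # (2N,)
            gate, up = gate_up.chunk(2)
            h = F.silu(gate) * up                   # (N,)
            out[t] += weight * (h @ w2[e].T)        # back to (K,)
    return out

The grouped_gemm_bf16 kernel replaces these per-token matmuls with one grouped launch over all experts.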

jikunshang and others added 10 commits August 1, 2025 00:59
* add cutlass
* fix import
* Fix accuracy and OOM issue

Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Ma, Liangliang <[email protected]>
CMakeLists.txt (outdated):

FetchContent_Declare(
  cutlass-sycl
-  GIT_REPOSITORY https://github.com/intel/cutlass-sycl
+  GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl
A collaborator commented:
Why are you using a privately forked cutlass-sycl?

@Liangliang-Ma (author) replied:
I will rebase to cutlass-sycl/main.

Signed-off-by: Ma, Liangliang <[email protected]>
@baodii left a comment:
LGTM

Signed-off-by: Ma, Liangliang <[email protected]>
@Liangliang-Ma changed the title from "[WIP] Grouped gemm cutlass" to "Grouped gemm cutlass" on Sep 25, 2025.
@jikunshang left a comment:
Some minor comments; they can be addressed in the next PR.

FUSEDMOE_AVAILABLE = True
except ImportError as e:
FUSEDMOE_UNAVAILABLE_REASON = str(e)
FUSEDMOE_AVAILABLE = False
Maybe we should log this error, or even raise it directly.
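
A minimal sketch of what the suggested logging could look like, assuming a module-level logger; the import path here is hypothetical:

import logging

logger = logging.getLogger(__name__)

try:
    from vllm_xpu_kernels import _fused_moe_C  # hypothetical import path
    FUSEDMOE_AVAILABLE = True
except ImportError as e:
    FUSEDMOE_UNAVAILABLE_REASON = str(e)
    FUSEDMOE_AVAILABLE = False
    # Surface the failure instead of silently disabling the fused-MoE path.
    logger.warning("fused MoE kernels unavailable: %s", e)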

torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset, N=n, K=k, **gemm_args)


def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids,
I think this may not be a good interface, but it's OK to keep it for now.
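
For context, a hedged sketch of how a caller might invoke this interface; every shape and value below is illustrative rather than taken from this PR:

import torch

E, K, N, T, topk = 16, 1024, 2048, 64, 1  # illustrative sizes only
hidden_states = torch.randn(T, K, dtype=torch.bfloat16, device="xpu")
w13 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16, device="xpu")  # fused gate/up
w2 = torch.randn(E, K, N, dtype=torch.bfloat16, device="xpu")
topk_weights = torch.rand(T, topk, device="xpu")
topk_ids = torch.randint(0, E, (T, topk), device="xpu")

out = cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids)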

expert_output = input @ weight.T
ref.append(expert_output)
pre_token_sum += cur_token_num
ref = torch.cat(ref, dim=0)
Better to make this a reference function.
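
A sketch of that suggestion: factor the inline loop above into a standalone reference function (variable names follow the snippet; the per-expert token counts are assumed precomputed):

import torch

def ref_grouped_gemm(inputs, weights, tokens_per_expert):
    # Naive per-expert GEMM used as the correctness reference.
    ref = []
    pre_token_sum = 0
    for expert_id, cur_token_num in enumerate(tokens_per_expert):
        chunk = inputs[pre_token_sum:pre_token_sum + cur_token_num]
        ref.append(chunk @ weights[expert_id].T)
        pre_token_sum += cur_token_num
    return torch.cat(ref, dim=0)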

return result


@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
Can you add a mini scope so we can run this on the simulator?
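
A hedged sketch of what such a reduced scope could look like; the sizes are illustrative:

import pytest

# Tiny problem sizes so the suite finishes quickly on a simulator.
FUSED_MOE_MNK_FACTORS_MINI = [
    (16, 32, 32),
    (32, 64, 64),
]

@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_MINI)
def test_fused_moe_mini(m, n, k):
    ...  # same body as the full test, just with smaller shapes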

@jikunshang merged commit 6c6be64 into vllm-project:main on Sep 25, 2025 (3 checks passed).
@Liangliang-Ma (author) replied:
> some minor comments. can address in next PR.

Got it. Thanks!

Comment on lines +435 to +443
// Epilogue operation: a plain linear combination (alpha * acc + beta * C).
using EpilogueOp =
    cutlass::epilogue::fusion::LinearCombination<float_t, float_t>;

// Build the Intel Xe collective epilogue from the tile shape, layouts,
// and dispatch policy.
using CollectiveEpilogue =
    typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
        Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto,
        float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1,
        EpilogueDispatchPolicy, EpilogueOp>::CollectiveOp;
@sanchitintel commented Sep 25, 2025:
Based on xe_builder.cpp and this code, it seems you used intel/sycl-tla#505 as a reference. Currently it uses EpilogueBuilder, but I'll replace that code with CollectiveEpilogue, which is more configurable.

Thanks!

@Liangliang-Ma (author) replied:
I think you may have used https://github.com/intel/cutlass-sycl/blob/b0cb10e655d8f9b1d0474e9538a82d218f74c694/benchmarks/gemm/gemm_configuration_sycl.hpp#L137C3-L137C87 as a reference too. Next time I will check your code to make sure mine isn't the same. Thanks!

@sanchitintel commented Sep 29, 2025:
I explicitly attributed the reference in the description of intel/sycl-tla#505. Not only does this give credit to the original author, it also makes maintenance easier.

Besides, I had told you on Sep 11 (Sep 12 for you) that I had fixed that issue; I created a PR for it the same day.

offset.append(0)

########### gemm1 ##################
# Re-lay out w13 so its last two dims are stored transposed (column-major in memory).
input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2)
@sanchitintel commented Sep 25, 2025:
You can use ColumnMajor B in the GroupGEMM kernel, so that you wouldn't have to transpose B. To use column-major B (and avoid transposing the weights), you can use a different copy atom for transposed loads. Coincidentally, because of how thread-value assignment works in the copy atoms, the transpose copy atom for 16-bit (or 8-bit, for that matter) dtypes loads data in VNNI format (which is also true for atoms ending in _N).

However, if the latency of transposing B plus a GEMM with RowMajor B is lower than a GEMM with ColumnMajor B (highly unlikely), then you might want to retain this approach.

FWIW, once a 32x32 transpose copy atom for BF16 is added to cutlass, the perf of GEMM with ColumnMajor B will get a bit better.

Thanks!
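
To make the layout question concrete, a small sketch of what the transpose-contiguous-transpose dance above produces: the logical shape is unchanged, but the last two dims end up stored column-major, which is the layout a ColumnMajor-B kernel could consume without the extra copy:

import torch

w = torch.randn(4, 8, 16)  # (E, N, K), row-major
w_cm = w.transpose(-1, -2).contiguous().transpose(-1, -2)

assert w_cm.shape == w.shape      # same logical shape
print(w.stride(), w_cm.stride())  # (128, 16, 1) vs. (128, 1, 8)
assert torch.equal(w, w_cm)       # same values, new memory layout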

@sanchitintel:

When I commented this morning, I still had this page open from the day before and didn't realize the PR had already been merged.
