Grouped gemm cutlass #22
Conversation
* add cutlass
* fix import

Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Ma, Liangliang <[email protected]>
Fix acc and oom issue
Signed-off-by: Ma, Liangliang <[email protected]>
CMakeLists.txt
FetchContent_Declare(
  cutlass-sycl
-  GIT_REPOSITORY https://github.com/intel/cutlass-sycl
+  GIT_REPOSITORY https://github.com/Liangliang-Ma/cutlass-sycl
Why are you using a private fork of cutlass-sycl?
I will rebase to cutlass-sycl/main.
LGTM
Some minor comments; they can be addressed in the next PR.
    FUSEDMOE_AVAILABLE = True
except ImportError as e:
    FUSEDMOE_UNAVAILABLE_REASON = str(e)
    FUSEDMOE_AVAILABLE = False
Maybe we should log this error, or even raise it directly.
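A minimal sketch of that suggestion (the logger setup and the imported module name are assumptions, not the PR's actual code):

import logging

logger = logging.getLogger(__name__)

try:
    from ._fused_moe_ops import cutlass_fused_moe  # hypothetical module/name
    FUSEDMOE_AVAILABLE = True
except ImportError as e:
    FUSEDMOE_UNAVAILABLE_REASON = str(e)
    FUSEDMOE_AVAILABLE = False
    # Surface the failure instead of silently disabling the fused path.
    logger.warning("Fused MoE unavailable: %s", e)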
torch.ops._xpu_C.cutlass_grouped_gemm(offset=offset, N=n, K=k, **gemm_args)

def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids,
I think this may not be a good interface, but it's OK to keep it for now.
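If the positional argument list keeps growing, one option is to bundle the inputs (purely illustrative; the dataclass name and fields are assumptions):

from dataclasses import dataclass

import torch

@dataclass
class FusedMoEInputs:
    # Bundles the tensors cutlass_fused_moe currently takes positionally.
    hidden_states: torch.Tensor
    w13: torch.Tensor
    w2: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor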
expert_output = input @ weight.T
ref.append(expert_output)
pre_token_sum += cur_token_num
ref = torch.cat(ref, dim=0)
Better to factor this out into a reference function.
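A sketch of such a helper, built from the loop body quoted above (the signature and token-count bookkeeping are assumptions):

import torch

def ref_grouped_gemm(inputs: torch.Tensor, weights: torch.Tensor,
                     tokens_per_expert: list[int]) -> torch.Tensor:
    # Reference path: run each expert's GEMM on its contiguous slice of tokens.
    ref = []
    pre_token_sum = 0
    for expert_id, cur_token_num in enumerate(tokens_per_expert):
        cur_input = inputs[pre_token_sum:pre_token_sum + cur_token_num]
        ref.append(cur_input @ weights[expert_id].T)
        pre_token_sum += cur_token_num
    return torch.cat(ref, dim=0)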
return result

@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
Can you add a mini scope (smaller shapes) so we can run this on the simulator?
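For example, something like this could serve as the mini scope (the shape values are illustrative assumptions, not tuned for any particular simulator):

import pytest

# Hypothetical mini scope: (m, n, k) shapes small enough for a simulator.
FUSED_MOE_MNK_FACTORS_MINI = [
    (8, 16, 16),
    (16, 32, 32),
]

@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_MINI)
def test_fused_moe_mini(m, n, k):
    ...  # body would mirror the full-scope test above, just with tiny shapes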
Got it. Thanks!
using EpilogueOp =
    cutlass::epilogue::fusion::LinearCombination<float_t, float_t>;

using CollectiveEpilogue =
    typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
        Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto,
        float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1,
        EpilogueDispatchPolicy, EpilogueOp>::CollectiveOp;
Based on xe_builder.cpp and this code, it seems you used intel/sycl-tla#505 as a reference. Currently, it's using EpilogueBuilder, but I'll replace that code with CollectiveEpilogue, which is more configurable.
Thanks!
I think you may have used https://github.com/intel/cutlass-sycl/blob/b0cb10e655d8f9b1d0474e9538a82d218f74c694/benchmarks/gemm/gemm_configuration_sycl.hpp#L137C3-L137C87 as a reference too. I will make sure my code isn't the same next time. Thanks!
I explicitly attributed the reference in the description of intel/sycl-tla#505.
Not only does it give credit to the original author, but it makes maintenance easier.
Besides, I had told you on Sep 11 (Sep 12 for you) that I had fixed that issue. I created a PR for it the same day.
offset.append(0)

########### gemm1 ##################
input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2)
You can use ColumnMajor B in the GroupGEMM kernel, so that you wouldn't have to transpose B.

To use column-major B (to avoid transposing weights), you can use a different copy atom for transposed loads. Coincidentally, because of how thread-value assignment works in the copy atoms, the transpose copy atom for 16-bit (or 8-bit, for that matter) dtypes will load data in VNNI format (which is also true for atoms ending in _N).

However, if the latency of transposing B plus a GEMM with RowMajor B is lower than a GEMM with ColumnMajor B (highly unlikely), then you might want to retain this approach.

FWIW, once a 32x32 transpose copy atom for BF16 is added in cutlass, the perf of GEMM with ColumnMajor B will get a bit better.

Thanks!
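For context, the line under discussion keeps the weight's logical shape but makes its storage column-major per expert; a minimal illustration (the shapes are arbitrary assumptions):

import torch

w = torch.randn(4, 8, 16)  # (num_experts, N, K), row-major storage
w_cm = w.transpose(-1, -2).contiguous().transpose(-1, -2)

assert w_cm.shape == w.shape          # logical shape is unchanged
assert w_cm.stride() == (128, 1, 8)   # innermost dim now strided: column-major per expert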
When I commented this morning, I already had this page open from yesterday and didn't know that it had already been merged.
Developing grouped_gemm_bf16 for the Llama4-Scout fused MoE.
Based on the #11 cutlass environment.
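For background, the grouped GEMM call shown earlier consumes per-expert group offsets; a hedged sketch of how they might be derived from topk_ids (hypothetical helper; the PR's actual routing code may differ):

import torch

def build_expert_offsets(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Count how many tokens are routed to each expert, then prefix-sum the
    # counts into the group boundaries a grouped GEMM consumes.
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    return torch.cumsum(counts, dim=0)

# Example: 4 routed tokens over 4 experts -> offsets tensor([1, 2, 4, 4])
print(build_expert_offsets(torch.tensor([0, 2, 1, 2]), num_experts=4))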