
Always use 64 as the block size of moe_align kernel to avoid lds out of limit #303

Merged
merged 3 commits into develop from charlifu/fix_moe_align_expert_num_too_big on Dec 5, 2024

Conversation


@charlifu charlifu commented Dec 4, 2024

The moe_align kernel uses the maximum of the number of experts and the warp size as the GPU block size. This causes LDS usage to exceed the hardware limit when handling models with a large number of experts.

This PR sets the block size to always be the warp size (64) and changes the kernel to handle the case where the thread count is smaller than the number of experts.
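A minimal sketch of the pattern the description refers to, not the PR's actual kernel: the block size is pinned to 64 and each thread strides over expert ids, so num_experts can exceed blockDim.x. The names tokens_cnts and index() follow the snippet quoted in the review below; the kernel name, counting logic, and host code are illustrative assumptions, written in CUDA-style C++ (the PR itself targets HIP/ROCm, where the wavefront size is 64).

```cuda
// Illustrative sketch only: fixed block size with thread-strided expert loops.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockSize = 64;  // fixed, independent of num_experts

__device__ __forceinline__ int index(int num_experts, int row, int col) {
  // Row-major indexing into the flattened (kBlockSize + 1) x num_experts table.
  return row * num_experts + col;
}

__global__ void count_tokens_per_expert(const int* expert_ids, int num_tokens,
                                        int num_experts, int* counts_out) {
  // One counting row per thread plus one base row, all in shared memory (LDS).
  extern __shared__ int tokens_cnts[];

  // num_experts can exceed blockDim.x, so each thread strides over expert ids.
  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
    for (int row = 0; row <= blockDim.x; ++row) {
      tokens_cnts[index(num_experts, row, eid)] = 0;
    }
  }
  __syncthreads();

  // Each thread tallies a strided slice of the tokens into its own row.
  for (int t = threadIdx.x; t < num_tokens; t += blockDim.x) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, expert_ids[t])];
  }
  __syncthreads();

  // Per-expert totals: again strided, one whole column per thread.
  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
    int total = 0;
    for (int row = 1; row <= blockDim.x; ++row) {
      total += tokens_cnts[index(num_experts, row, eid)];
    }
    counts_out[eid] = total;
  }
}

int main() {
  const int num_tokens = 1024, num_experts = 128;  // more experts than threads
  int *ids, *counts;
  cudaMallocManaged(&ids, num_tokens * sizeof(int));
  cudaMallocManaged(&counts, num_experts * sizeof(int));
  for (int t = 0; t < num_tokens; ++t) ids[t] = t % num_experts;

  const size_t shmem = (kBlockSize + 1) * num_experts * sizeof(int);
  count_tokens_per_expert<<<1, kBlockSize, shmem>>>(ids, num_tokens,
                                                    num_experts, counts);
  cudaDeviceSynchronize();
  printf("expert 0 got %d tokens\n", counts[0]);  // expect 1024 / 128 = 8
  cudaFree(ids);
  cudaFree(counts);
  return 0;
}
```

With the block size fixed at 64, the shared-memory footprint in this sketch scales as (64 + 1) * num_experts ints rather than roughly (num_experts + 1) * num_experts, which is what a block size of max(num_experts, warp size) implies for models with many experts.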

@charlifu charlifu requested review from gshtras and shajrawi December 4, 2024 21:08
@shajrawi shajrawi (Collaborator) left a comment

Kudos Charlie!

@charlifu charlifu changed the title Always use 64 as the block size to avoid lds out of limit Always use 64 as the block size of moe_align kernel to avoid lds out of limit Dec 4, 2024
```diff
-  if (threadIdx.x < num_experts) {
-    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+    tokens_cnts[index(num_experts, 0, eid)] = 0;

   for (int i = 1; i <= blockDim.x; ++i) {
```
Collaborator
Does this have to be a nested loop?
Could a reduce operation be more efficient here?

Author
Each thread is doing its own reduce op on an independent data array, so a warp reduce primitive does not apply in this case.
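For context on this exchange, here is a hedged CUDA-style illustration of the distinction (the PR targets 64-wide AMD wavefronts, this toy uses 32-lane CUDA warps, and all names are made up for the example): a warp reduce folds one value per lane into a single result, whereas in the kernel each lane must sum its own independent column, so there is nothing for the lanes to exchange.

```cuda
// Illustrative contrast only: warp-level reduction vs. per-lane column sums.
#include <cstdio>
#include <cuda_runtime.h>

// Warp reduce: 32 lanes cooperate to fold one value per lane into a single sum.
__device__ int warp_sum(int v) {
  for (int offset = 16; offset > 0; offset /= 2) {
    v += __shfl_down_sync(0xffffffff, v, offset);  // lanes exchange partials
  }
  return v;  // lane 0 ends up holding the total of all lanes' inputs
}

// Per-lane reduce: each lane folds its own column of an independent table.
// Lanes never need each other's data, so a shuffle-based reduction has
// nothing to combine across lanes.
__device__ int column_sum(const int* col, int rows, int stride) {
  int total = 0;
  for (int i = 0; i < rows; ++i) {
    total += col[i * stride];
  }
  return total;
}

__global__ void demo(const int* table, int rows, int cols, int* out_warp,
                     int* out_cols) {
  int lane = threadIdx.x;  // launched with a single 32-thread warp
  int combined = warp_sum(lane);  // one result for the whole warp
  if (lane == 0) *out_warp = combined;
  out_cols[lane] = column_sum(table + lane, rows, cols);  // one result per lane
}

int main() {
  const int rows = 8, cols = 32;
  int *table, *out_warp, *out_cols;
  cudaMallocManaged(&table, rows * cols * sizeof(int));
  cudaMallocManaged(&out_warp, sizeof(int));
  cudaMallocManaged(&out_cols, cols * sizeof(int));
  for (int i = 0; i < rows * cols; ++i) table[i] = 1;

  demo<<<1, 32>>>(table, rows, cols, out_warp, out_cols);
  cudaDeviceSynchronize();
  printf("warp_sum of lane ids = %d, column 0 sum = %d\n", *out_warp,
         out_cols[0]);  // expect 496 and 8
  cudaFree(table);
  cudaFree(out_warp);
  cudaFree(out_cols);
  return 0;
}
```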

@charlifu charlifu merged commit b414ae9 into develop Dec 5, 2024
7 checks passed
@gshtras gshtras deleted the charlifu/fix_moe_align_expert_num_too_big branch December 7, 2024 03:21