vulkan: Update topk_moe fusion to handle gpt's late softmax #16656
base: master
Conversation
CC @am17an I've included the ggml_check_edges change in this PR.
I understand what this change is doing, but how do I test it? The topk_moe tests pass before and after this change. Which model architectures correspond to the three modes?
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
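For context, here is a minimal sketch of how an expected-op list like the one quoted above could be matched against a run of graph nodes. The helper name `graph_ops_match` is hypothetical, and the sketch only compares op types; the real fusion checks in this PR (including ggml_check_edges) also validate how the candidate nodes are wired together.

```cpp
#include "ggml.h"

// Hypothetical helper (not the code in this PR): return true if the ops of
// the nodes starting at `start` match the expected sequence exactly.
static bool graph_ops_match(struct ggml_cgraph * graph, int start,
                            const enum ggml_op * ops, int n_ops) {
    if (start + n_ops > ggml_graph_n_nodes(graph)) {
        return false;
    }
    for (int i = 0; i < n_ops; ++i) {
        if (ggml_graph_node(graph, start + i)->op != ops[i]) {
            return false;
        }
    }
    return true;
}

// Illustrative use, with the fragment quoted above as the expected tail:
// static const enum ggml_op expected_ops[] = {
//     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
//     GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,
// };
// bool fuse = graph_ops_match(graph, i, expected_ops, 4);
```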
//node #963 ( SOFT_MAX): ffn_moe_probs-15 ( 64K) [Vulka ] use=2: ffn_moe_logits-15 ( 64K) [Vulka ]
Vulka?
This logging is from ggml_backend_sched_print_assignments (set GGML_SCHED_DEBUG=2 and run llama-bench with -v; I also edit the code not to skip views). It truncates the name to keep the formatting aligned.
Usually I put a debug statement printing the number of nodes fused. We'll need to come up with a better way to assert that the nodes were actually fused.
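As an aside, a throwaway counter of the kind described above might look something like this sketch (the names and the reporting point are assumptions, and this is not the logging that was actually added to the PR):

```cpp
#include <cstdio>

// Illustrative only: accumulate the number of nodes consumed by fusion while
// walking the graph, then print one summary line per graph evaluation.
static int g_fused_nodes = 0;

static void debug_count_fused(int n_fused) {
    g_fused_nodes += n_fused;  // call at each fusion site with the run length
}

static void debug_report_fused() {
    fprintf(stderr, "debug: fused %d nodes in this graph\n", g_fused_nodes);
    g_fused_nodes = 0;
}
```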
I've added some logging in the latest commit that I use to verify fusion and the effects of graph_optimize. You can see the whole sequence of ops without a sync in between, which implies the fusion is working. Early softmax w/norm: qwen3
Based on #16649.