Llama 3.1 8B fp16 TP8 sharded fails to compile for CPU and GPU #19263
Comments
About the CPU compilation error: I made a fix when exporting for the unsharded case, where we want no device affinities. This is a sharded variant. At first glance the argument and global parameter affinities look fine. It is probably something with the …
The GPU error is in an attention dispatch that fails to compile. It is going down the … You can run this with …
Here is the full dump. Looks like it has some dynamic shapes, so I am guessing vector distribute bailed on it.
I think @sogartar suggested we not compile with …
Yes, I was just using that to be concise; you can use the new flags and get the same error too.
On the GPU side, this looks like it is coming from the inner unit dims on the K2 dimension of attention. We could either collapse those unit dims to make it work, or I can send a patch tomorrow to add support for multiple M/N dimensions for intrinsic targeting.
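As a rough sketch of the first option, collapsing the two inner unit dims on the K2 operand would amount to something like the following (the value names are hypothetical; the shapes are taken from the IR quoted in the next comment):
// Fold the inner unit dims of the K2 operand away before it reaches the attention op.
%k2_collapsed = tensor.collapse_shape %k2 [[0], [1, 2, 3], [4]]
    : tensor<4x?x1x1x128xf16> into tensor<4x?x128xf16>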
@aviator19941 for some context, is this a regression or something that never worked?
Edit: I thought this might have to do with folding
%expanded_8645 = tensor.expand_shape %28622 [[0, 1, 2], [3, 4]] output_shape [4, %21, 1, 1, 128]
    : tensor<?x128xf16> into tensor<4x?x1x1x128xf16>
%collapsed_8653 = tensor.collapse_shape %expanded_8645 [[0], [1, 2, 3], [4]]
    : tensor<4x?x1x1x128xf16> into tensor<4x?x128xf16>
to a single expand shape.
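For reference, that expand/collapse pair is semantically equivalent to a single expand_shape along these lines (a sketch reusing the value names from the IR above):
// Split only the dynamic leading dim; the trailing 128 dim stays untouched.
%folded = tensor.expand_shape %28622 [[0, 1], [2]] output_shape [4, %21, 128]
    : tensor<?x128xf16> into tensor<4x?x128xf16>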
This is a regression, @IanWood1. "I can send a patch tomorrow to add support for multiple M/N dimensions for intrinsic targeting" <- this seems like the right thing to do.
This should fix it:
This change adds `linalgExtExpansionFn` to limit sinking of `collapse_shape` ops through `iree_linalg_ext.attention` to cases where the k2 dimensions are not expanded by the reshape fusion. Currently, GPU codegen cannot support unit dims on the k2 dimensions, so any `collapse_shape` that expands out unit dimensions on these dims will cause compilation errors. This fixes the unit dim error in #19263, but it uncovered further, unrelated compilation errors tracked in #19377.
Signed-off-by: Ian Wood <[email protected]>
What happened?
When I try to compile the sharded Llama 3.1 8B fp16 IR for CPU or GPU, I get this error for CPU:
https://gist.github.com/aviator19941/82bceb2624571d446da0964440790fde
and this error for GPU:
https://gist.github.com/aviator19941/89761b3bbb6ace5a6945de667e6d1e39
I also tried the flags that were suggested for compiling Llama:
--iree-dispatch-creation-enable-aggressive-fusion=true --iree-global-opt-propagate-transposes=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-data-tiling=false --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' --iree-hal-indirect-command-buffers=true --iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=false --iree-hal-memoization=true --iree-opt-strip-assertions
Steps to reproduce your issue
../iree-build-no-trace/tools/iree-compile 8b_f16_tp8_decomposed.mlir -o=8b_f16_tp8_decomposed_cpu.vmfb --iree-hal-target-device=llvm-cpu[0] --iree-hal-target-device=llvm-cpu[1] --iree-hal-target-device=llvm-cpu[2] --iree-hal-target-device=llvm-cpu[3] --iree-hal-target-device=llvm-cpu[4] --iree-hal-target-device=llvm-cpu[5] --iree-hal-target-device=llvm-cpu[6] --iree-hal-target-device=llvm-cpu[7]
../iree-build-no-trace/tools/iree-compile 8b_f16_tp8_decomposed.mlir --iree-hip-target=gfx942 -o=8b_f16_tp8_decomposed.vmfb --iree-hal-target-device=hip[0] --iree-hal-target-device=hip[1] --iree-hal-target-device=hip[2] --iree-hal-target-device=hip[3] --iree-hal-target-device=hip[4] --iree-hal-target-device=hip[5] --iree-hal-target-device=hip[6] --iree-hal-target-device=hip[7]
What component(s) does this issue relate to?
No response
Version information
iree-base-compiler 3.1.0rc20241121
Additional context
No response