Numeric issue for llama_8b_fp8 model on hip #19859
Comments
Might be worth trying with
module__initializer_0_dispatch_0.mlir (attached). So the first dispatch outputs NaN. With
I tried to dump the dispatch tensor inputs/outputs via the compile command and then ran the module:
https://gist.github.com/drprajap/e0b5c399e4a2047e42c5b616cb99db85
I did a binary search over the torch IR; the NaNs start from here, i.e., the fp8 matmul. Command to run with inp1.bin: https://sharkpublic.blob.core.windows.net/sharkpublic/prashant/inp1.bin
Based on the faulty.mlir @pashu123 found, I tried deleting the fp8 torch.aten.mm (faulty_inp2.mlir / faulty_inp1.mlir) and only returning the input bin, and there is no NaN after iree-run-module. So I think there is nothing wrong with the inp1.bin/inp2.bin Prashant provided; the NaN issue comes from the fp8 torch.aten.mm.
@AmosLewis can you attach the dump from adding
btw, the most relevant people here are in the UK. I asked @nirvedhmeshram to take a look, but we will likely not get many eyes on this until Monday.
Here's the dump: https://gist.github.com/pashu123/07f94ba18756b36891828b88a56f9a55
Here is a little experiment I did to convince myself that this is an overflow issue in the matmul: the truncation seems to be unnecessarily going to f8 and then back to bf16, which I believe causes the overflow.
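Since the experiment's attachment did not survive in this thread, here is a minimal sketch of the same reasoning (my reconstruction with made-up operand values; the only hard number is the f8E4M3FNUZ maximum of 240):

```python
# Illustrative only: why fp32 -> f8 -> bf16 overflows while fp32 -> bf16 does not.
F8_E4M3FNUZ_MAX = 240.0   # largest finite f8E4M3FNUZ magnitude
BF16_MAX = 3.39e38        # approximate largest finite bf16 magnitude

# Emulate the fp8 matmul's fp32 accumulator with made-up in-range f8 operands.
a = [8.0] * 16
b = [4.0] * 16
acc = sum(x * y for x, y in zip(a, b))   # 512.0, well within fp32/bf16 range

print(abs(acc) <= BF16_MAX)         # True: truncating the accumulator straight to bf16 is fine
print(abs(acc) <= F8_E4M3FNUZ_MAX)  # False: the intermediate f8 truncation is where it blows up
```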
@nirvedhmeshram I just looked at the dump; this pattern first appears after op fusion (ElementwiseOpFusionPass). Any idea how to work around it? Should it be fixed by adding a pattern to fuse it into
I asked @MaheshRavishankar about it and he said we can't do anything in the compiler in a foolproof way. We need to understand why the model is doing an f8 matmul and then always casting it back to bf16, and whether something can be done at the torch level, since we see this pattern everywhere in the model.
If the value goes out of range after the truncate, there is really nothing the compiler can do here. In effect it is maintaining program semantics, because the semantics of the program are to go out of bounds. I think the issue is with the quantization: the weights haven't been quantized correctly to keep the result in the range of f8 values.
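In other words (my paraphrase; the only hard number is that the largest finite f8E4M3FNUZ value is 240), the quantization scales have to guarantee

$$
\Bigl|\sum_{k} A_{ik}\,B_{kj}\Bigr| \;\le\; 240
$$

for every product the program materializes in f8; if that bound is violated, the out-of-range result is exactly what the compiler is required to produce.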
I believe this model does attention in bf16, based on how the quantization team quantized it, so it is likely casting up to that dtype and back for each SDPA op. @dan-garvey would know the specifics of how/why that is being done.
The model is quantized to work in the following way: do some matmuls in fp8. The intrinsic for fp8 matmul on MI300 outputs fp32 -> truncate to bf16.
I see; that's not what the torch IR snippet I showed above will do, since the aten.mm has ->f8 in it. It is first truncating to f8 and then extending to bf16. I think we want aten.mm ops with (f8, f8) -> bf16; then it will do the right thing.
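For illustration only (this is not the actual sharktank change; `fp8_matmul_bf16_out` is a name I made up, and it assumes a PyTorch build that has the `torch.float8_e4m3fnuz` dtype), the intended shape of the op is roughly:

```python
import torch

def fp8_matmul_bf16_out(a_f8: torch.Tensor, b_f8: torch.Tensor) -> torch.Tensor:
    # Emulate the MI300 fp8 MFMA: accumulate in fp32, then truncate once to bf16.
    # The important part is that nothing is cast back down to f8 after the matmul.
    return (a_f8.to(torch.float32) @ b_f8.to(torch.float32)).to(torch.bfloat16)

# Small in-range example inputs (0.5 is exactly representable in f8E4M3FNUZ).
a = torch.full((4, 8), 0.5).to(torch.float8_e4m3fnuz)
b = torch.full((8, 4), 0.5).to(torch.float8_e4m3fnuz)
out = fp8_matmul_bf16_out(a, b)   # dtype: torch.bfloat16, every value 2.0
```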
Ok, so we need a custom kernel that does that. I think they did the same thing for punet. I'll try to get new IR going over the weekend.
nod-ai/shark-ai#896 |
@dan-garvey This brings new issues. index_put's input and the values being put should be the same type, but your change makes them different: the input is bf16 but the values are f8.
indexput.torch.mlir
Besides, for the torch.aten.mm, only the dims changed from static to dynamic; I didn't see any type change.
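A minimal sketch of the dtype constraint being described (illustrative shapes and dtypes, not the model's; again assuming a PyTorch build with `torch.float8_e4m3fnuz`):

```python
import torch

cache = torch.zeros(4, 8, dtype=torch.bfloat16)            # stand-in for the bf16 KV cache
vals = torch.full((2, 8), 0.25).to(torch.float8_e4m3fnuz)  # values produced in f8
idx = (torch.tensor([0, 2]),)

# cache.index_put_(idx, vals) fails: index_put requires the values to have the
# same dtype as the tensor being written into.
cache.index_put_(idx, vals.to(cache.dtype))                 # cast first, then it works
```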
The numerics are fixed with the newly generated llama_8b_f8.mlir from Dan's local shark-ai. The generated artifacts: llama_8b_f8.vmfb, llama_8b_f8_prefill.tracy, llama_8b_f8_decode.tracy. BTW, the current fix in nod-ai/shark-ai#896 is not enough for now (a98a332); I copied the mlir file from Dan's local directory.
What happened?
Follow-up of #19809
Here is the input MLIR: llama_8b_fp8.mlir
The input .bin files can be copied from the shared folders (SharkMI300X, /sharedfile/prefill/ and /sharedfile/decode/) or downloaded from the following links:
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_seq_block_ids_1_1_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_seq_lens_1_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_token_ids_1_32_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_cache_state_128_2097152_f8E4M3FNUZ.bin
I tried to create the dispatches and locate where the NaNs start. The NaNs happen from the very beginning: I found them in module___builtin_fill_i64.mlir, module__initializer_0_dispatch_0.mlir, and module_prefill_bs1$async_dispatch_0.mlir. I don't know the order of these three (they are all named dispatch_0), so I list all of them here.
I verified the input .bin files with https://hexed.it/; none of them contain NaNs. @benvanik
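For anyone who prefers a scriptable check over hexed.it, a sketch along these lines should work (assumes the ml_dtypes package for the f8E4M3FNUZ numpy dtype; the dtype is read off the filename):

```python
import numpy as np
import ml_dtypes  # provides float8_e4m3fnuz as a numpy dtype

def has_nan(path: str, dtype) -> bool:
    data = np.fromfile(path, dtype=dtype)
    # Upcast to float32 before the NaN test so we don't rely on f8 ufunc support.
    return bool(np.isnan(data.astype(np.float32)).any())

print(has_nan("prefill_cache_state_128_2097152_f8E4M3FNUZ.bin", ml_dtypes.float8_e4m3fnuz))
# The *_i64.bin inputs are integers, so they cannot contain NaN by construction.
```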

@MaheshRavishankar Could you assign someone to fix this numeric issue?
Steps to reproduce your issue
Get the dispatch files; they are something like:
module___builtin_fill_i64.mlir
module__initializer_0_dispatch_0.mlir ... module__initializer_10_dispatch_0.mlir
module_prefill_bs1$async_dispatch_0.mlir ... module_prefill_bs1$async_dispatch_806.mlir
module_decode_bs1$async_dispatch_0.mlir ... module_decode_bs1$async_dispatch_680.mlir
What component(s) does this issue relate to?
Compiler
Version information
Additional context
No response