Numeric issue for llama_8b_fp8 model on hip #19859

Open
AmosLewis opened this issue Jan 30, 2025 · 20 comments · May be fixed by nod-ai/shark-ai#896
Labels
bug 🐞 Something isn't working

Comments

@AmosLewis
Contributor

AmosLewis commented Jan 30, 2025

What happened?

Follow up of #19809

/home/chi/src/iree-build/tools/iree-run-module \
--hip_use_streams=true \
--module=fp8.vmfb \
--parameters=model=fp8.irpa \
--device=hip://4 \
--function=prefill_bs1 \
--input=1x32xi64=@/sharedfile/prefill/prefill_token_ids_1_32.bin \
--input=1xi64=@/sharedfile/prefill/prefill_seq_lens_1.bin \
--input=1x1xi64=@/sharedfile/prefill/prefill_seq_block_ids_1_1.bin \
--input=128x2097152xf8E4M3FNUZ=@/sharedfile/prefill/prefill_cache_state_128_2097152.bin
EXEC @prefill_bs1
result[0]: hal.buffer_view
1x32x128256xbf16=[[NAN NAN NAN NAN NAN...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]]

Here is the input MLIR: llama_8b_fp8.mlir
The input .bin files can be copied from the SharkMI300X machine (/sharedfile/prefill/ and /sharedfile/decode/) or downloaded from the following links:

https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_seq_block_ids_1_1_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_seq_lens_1_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_token_ids_1_32_i64.bin
https://sharkpublic.blob.core.windows.net/sharkpublic/chi/llama/prefill/prefill_cache_state_128_2097152_f8E4M3FNUZ.bin

I tried to create the dispatches and locate where the NAN values start. The NANs appear from the very beginning. I found them in module___builtin_fill_i64.mlir, module__initializer_0_dispatch_0.mlir, and module_prefill_bs1$async_dispatch_0.mlir. I don't know the order of these three; they are all named dispatch 0, so I list all of them here.

I verified the input .bin files with https://hexed.it/; none of them contain NAN. @benvanik
Image

@MaheshRavishankar Could you assign someone to fix this numeric issue?

Steps to reproduce your issue

  1. Check out IREE at this commit:
commit 3f713f5f4743ddb6715e4ea4a361784b54489e5a (HEAD -> main, upstream/main)
Author: Jakub Kuderski <[email protected]>
Date:   Wed Jan 29 12:49:42 2025 -0500

    [ROCm] Add mi325x to known targets (#19846)
  2. cmake
 cmake -G Ninja -B ../iree-build   -S . -DCMAKE_BUILD_TYPE=Debug   \
-DIREE_ENABLE_ASSERTIONS=ON   -DCMAKE_C_COMPILER=clang   \
-DCMAKE_CXX_COMPILER=clang++   -DIREE_ENABLE_RUNTIME_TRACING=ON   \
-DIREE_BUILD_TRACY=OFF   -DIREE_ENABLE_LLD=ON   \
-DIREE_BUILD_PYTHON_BINDINGS=ON   \
-DPython3_EXECUTABLE="$(which python3)"  \
-DIREE_TARGET_BACKEND_CUDA=OFF -DIREE_HAL_DRIVER_HIP=ON \
-DIREE_TARGET_BACKEND_ROCM=ON .

cmake --build ../iree-build
  3. Generate the vmfb along with all dispatch sources, and pick a dispatch function name as a break point.
/home/chi/src/iree-build/tools/iree-compile fp8.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions \
  --iree-hal-dump-executable-sources-to=/home/chi/src/test/llama/dispatch/ \
  -o=fp8.vmfb

The dispatch files you get look something like:
module___builtin_fill_i64.mlir
module__initializer_0_dispatch_0.mlir ... module__initializer_10_dispatch_0.mlir
module_prefill_bs1$async_dispatch_0.mlir ... module_prefill_bs1$async_dispatch_806.mlir
module_decode_bs1$async_dispatch_0.mlir ... module_decode_bs1$async_dispatch_680.mlir

  4. Create a vmfb that runs up to a certain dispatch 0.
/home/chi/src/iree-build/tools/iree-compile fp8.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions \
  --mlir-print-debuginfo \
  --iree-util-zero-fill-elided-attrs \
  --iree-flow-break-dispatch=@_initializer_0_dispatch_0 \
  -o=fp8__initializer_0_dispatch_0.vmfb
  5. Run iree-run-module up to the chosen dispatch 0.
 /home/chi/src/iree-build/tools/iree-run-module \
--hip_use_streams=true \
--module=fp8__initializer_0_dispatch_0.vmfb \
--parameters=model=fp8.irpa \
--device=hip://4 \
--function=prefill_bs1 \
--input=1x32xi64=@/sharedfile/prefill/prefill_token_ids_1_32.bin \
--input=1xi64=@/sharedfile/prefill/prefill_seq_lens_1.bin \
--input=1x1xi64=@/sharedfile/prefill/prefill_seq_block_ids_1_1.bin \
--input=128x2097152xf8E4M3FNUZ=@/sharedfile/prefill/prefill_cache_state_128_2097152.bin
EXEC @prefill_bs1
result[0]: hal.buffer_view
1x32x128256xbf16=[[NAN NAN NAN...
  6. Repeat steps 4 & 5 for other dispatches.

What component(s) does this issue relate to?

Compiler

Version information

commit 3f713f5f4743ddb6715e4ea4a361784b54489e5a (HEAD -> main, upstream/main)
Author: Jakub Kuderski <[email protected]>
Date:   Wed Jan 29 12:49:42 2025 -0500

    [ROCm] Add mi325x to known targets (#19846)

Additional context

No response

@AWoloszyn
Contributor

From your Tracy trace: This is the first dispatch that is run.
Image

@AWoloszyn
Contributor

Might be worth trying with --parameter_mode="mmap" or --parameter_mode="preload" to rule out an issue with file uploading.

@AmosLewis
Contributor Author

AmosLewis commented Jan 30, 2025

From your Tracy trace: This is the first dispatch that is run. Image

So the first dispatch, module__initializer_0_dispatch_0.mlir, outputs NAN.

With --parameter_mode="mmap" or --parameter_mode="preload", I get the same NAN output:
parameter_mode_nan.log

@drprajap

I tried dumping the input/output tensor data via --iree-flow-trace-dispatch-tensors, and it did give me non-NaN values after the specific dispatch in question, initializer_0_dispatch_0_elementwise_broadcast_64_i64.

compile command used
../../iree-build-trace/tools/iree-compile llama_8b_fp8_tp8.mlir --iree-hip-target=gfx942 --iree-hal-target-device=hip --iree-dispatch-creation-enable-aggressive-fusion=true --iree-global-opt-propagate-transposes=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-data-tiling=false --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' --iree-hal-indirect-command-buffers=true --iree-stream-resource-memory-model=discrete --iree-hal-memoization=true --iree-opt-strip-assertions --mlir-print-debuginfo --iree-util-zero-fill-elided-attrs --iree-flow-break-dispatch=@_initializer_0_dispatch_0 -o=fp8__initializer_0_dispatch_0.vmfb --iree-flow-trace-dispatch-tensors --iree-hal-dump-executable-files-to=01_30_dump

run module:
/home/diprajap/workspace/iree-build-trace/tools/iree-run-module --hip_use_streams=true --module=/home/diprajap/workspace/artifacts/8b_fp8_tp8/fp8__initializer_0_dispatch_0.vmfb --parameters=model=llama3_8b_fp8.irpa --device=hip://4 --function=prefill_bs1 --input=1x32xi64=@/home/diprajap/workspace/artifacts/8b_fp8_tp8/prefill/prefill_token_ids_1_32.bin --input=1xi64=@/home/diprajap/workspace/artifacts/8b_fp8_tp8/prefill/prefill_seq_lens_1.bin --input=1x1xi64=@/home/diprajap/workspace/artifacts/8b_fp8_tp8/prefill/prefill_seq_block_ids_1_1.bin --input=128x2097152xf8E4M3FNUZ=@/home/diprajap/workspace/artifacts/8b_fp8_tp8/prefill/prefill_cache_state_128_2097152.bin


=== _initializer_0_dispatch_0::_initializer_0_dispatch_0_elementwise_broadcast_64_i64 outputs ===
64xi64=0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 76 78 80 82 84 86 88 90 92 94 96 98 100 102 104 106 108 110 112 114 116 118 120 122 124 126

https://gist.github.com/drprajap/e0b5c399e4a2047e42c5b616cb99db85

@pashu123
Contributor

@AmosLewis
Contributor Author

Using this Python script test_nan.py (https://gist.github.com/pashu123/9a092f901d670ce4a5b898c72221e0d7), inp2.bin contains NaN. I have to go back and check :).

Based on the faulty.mlir that @pashu123 found, I tried deleting the fp8 torch.aten.mm (faulty_inp2.mlir / faulty_inp1.mlir) and only returning the input; there is no NaN after iree-run-module. So I think there is nothing wrong with the input .bin files (inp1.bin / inp2.bin) that Prashant provided; the NAN issue comes from the fp8 torch.aten.mm.
Another thing I found is that after the fp8 torch.aten.mm, there are lots of integer values besides the NAN entries in the bf16 result, which does not make sense to me either.
@IanNod mentioned in the lead sync that we need to prioritize and push this work today. @MaheshRavishankar, could you find more codegen folks to continue on this in the US timezone?

faulty_inp2.mlir

func.func @faulty(%arg0: !torch.vtensor<[32,4096],bf16>, %arg1: !torch.vtensor<[4096,4096],bf16>) -> !torch.vtensor<[32,4096],bf16> {
  %int26 = torch.constant.int 26
  %int15 = torch.constant.int 15
  return %arg0 : !torch.vtensor<[32,4096],bf16>
}
iree-compile faulty_inp2.mlir  --iree-hip-target=gfx942  --iree-hal-target-device=hip   --iree-dispatch-creation-enable-aggressive-fusion=true   --iree-global-opt-propagate-transposes=true   --iree-opt-aggressively-propagate-transposes=true   --iree-opt-data-tiling=false   --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'   --iree-hal-indirect-command-buffers=true   --iree-stream-resource-memory-model=discrete   --iree-hal-memoization=true   --iree-opt-strip-assertions -o=init2.vmfb

iree-run-module --hip_use_streams=true --device=hip://0 --function=faulty --module=init2.vmfb [email protected] [email protected]
EXEC @faulty
result[0]: hal.buffer_view
32x4096xbf16=[0.0437012 -2.15625 -0.90625 -0.464844 -14.875 0.636719 0.000701904 0.0437012 0.472656 -0.322266 -1.39844 0.105957 0.234375 0.345703 -0.0629883 -0.041748 -0.0471191 -0.0164795 0.917969 0.882812 -1.41406 0.0107422 0.0196533 -5.1875 0.0493164 0.0446777 -0.244141 0.474609 -2.73438 0.511719 0.302734 -0.566406 0.169922 -0.933594 0.390625 -1.51562 -1.1875 0.722656 -0.480469 -7.40625 0.0380859 -0.000541687 -2.125 -0.435547 0.267578 -1.85938 -0.0354004 0.00860596 -0.259766 -0.558594 -3.04688 0.0141602 3.1875 -0.515625 -0.0991211 -0.144531 0.578125 -0.137695 -0.105957 -3.82812 0.000888824 -0.000442505 -1.51562 0.0144653 -2.29688 -0.189453 0.376953 0.816406 -0.0397949 -0.0344238 -0.000118732 -0.000284195 -0.00212097 2.46875 1.63281 -1.72656 -0.000383377 1.85938 -0.0017395 2.07812 0.519531 -1.16406 -3.85938 -0.151367 -0.00476074 1.17969 0.0125732 -5.375 0.24707 -0.1875 -0.546875 -6.5 0.392578 0.048584 -0.160156 1.49219 0.0537109 -0.203125 -0.0296631 -0.141602 -0.298828 0.0424805 0.300781 13.4375 -1.89844 0.026001 -0.0235596 2.375 0.00442505 -0.0274658 -0.0524902 -4.90625 0.224609 1.4375 -1.0625 -1.03906 -0.542969 -2.15625 -0.398438 -1.86719 -3.09375 -1.95312 -0.582031 0.166992 -1.17969 -0.0108643 -2.1875 -0.800781 -3.17188 0.322266 0.621094 0.0014267 1.11719 23.375 0.0126343 1.08594 -0.332031 0.699219 -0.0015564 0.239258 -0.804688 0.026123 0.3125 -0.0133667 6.53125 0.0678711 -0.863281 -0.847656 -1.41406 0.0306396 0.00848389 -0.146484 0.113281 0.0649414 0.219727 -0.15918 0.0844727 -0.285156 -0.314453 0.431641 -0.523438 0.769531 -0.349609 0.0517578 -0.765625 1.85156 1.04688 -0.115234 -0.455078 -0.00038147 -0.0106201 0.0212402 0.189453 -0.046875 2.75 -0.0230713 0.00357056 0.00030899 0.199219 0.318359 0.263672 -0.202148 -0.0732422 -1.92188 1.125 -0.363281 1.35156 0.0380859 0.597656 -0.0766602 0.285156 0.00056839 0.000873566 -1.63281 10.25 0.0319824 0.208008 -0.125 -0.0515137 10.4375 -0.0544434 0.746094 -0.00338745 0.371094 1.11719 5.8125 -0.000823975 -0.0673828 -0.318359 -0.124512 0.902344 0.0976562 8.4375 -0.380859 -0.730469 -0.347656 0.820312 -0.0756836 0.111816 -0.0163574 -2.0625 -0.204102 -0.0218506 -0.867188 -0.0270996 -4.59375 -0.207031 2.78125 2.29688 -0.421875 0.00689697 0.00805664 -0.585938 -0.00166321 -0.9375 -0.00308228 0.102051 0.757812 0.0893555 1.07031 0.00352478 0.000478745 -0.0688477 0.109863 0.0673828 0.00062561 -0.182617 0.925781 -4.125 -0.0412598 -0.00430298 0.214844 -0.152344 0.539062 -0.125977 0.102539 0.257812 -2.32812 0.029541 2.26562 2.1875 -0.84375 -0.433594 -1.47656 18.625 0.151367 -1.21094 0.0991211 1.15625 -1.02344 0.000335693 -0.230469 -0.396484 0.435547 0.235352 -0.324219 -0.175781 -0.000364304 0.871094 -0.003479 -0.0349121 1.79688 -0.578125 -7.21875 0.00576782 0.0253906 -0.0600586 -0.00262451 0.048584 -0.125 3.3125 -15.6875 -0.730469 -0.0223389 -0.188477 -0.000762939 -0.707031 -1.01562 -2.21875 -0.291016 -0.00811768 -2.10938 -0.0252686 0.765625 -0.00363159 -0.0893555 -0.408203 -0.125977 -0.034668 0.460938 2.10938 -0.308594 -0.726562 -0.124023 6.96875 1.85938 -0.0366211 0.150391 1.27344 0.090332 -1.66406 -0.136719 0.0424805 0.285156 -0.515625 -0.902344 -0.234375 0.117676 1.39844 0.208984 0.271484 -0.515625 0.0688477 0.0620117 0.0133057 0.275391 -0.0493164 0.0507812 0.0038147 -0.151367 -0.400391 -0.126953 0.000135422 -0.597656 0.000583649 -0.835938 -4.09375 -0.949219 -1.9375 -0.192383 1.14844 0.000488281 -0.201172 -0.010498 0.036377 -0.433594 -10.25 -0.0961914 -0.0402832 0.0600586 0.0708008 0.141602 -0.140625 3.04688 -1.45312 -10.9375 -0.000151634 0.585938 
0.421875 0.632812 0.208008 -0.984375 0.210938 -0.241211 0.585938 -0.757812 -0.000222206 -0.443359 -0.445312 0.0952148 -0.000253677 -0.929688 0.0327148 -0.232422 0.941406 -0.000108719 0.550781 0.318359 0.388672 -2.5 1.22656 -0.00224304 -0.722656 -0.621094 0.114746 -0.332031 0.0247803 -0.527344 -0.000858307 -1.94531 -0.0228271 0.535156 -0.0771484 0.333984 -0.109375 0.5625 -0.408203 0.392578 -3.03125 0.175781 -0.351562 0.699219 0.605469 0.00105286 -0.388672 -0.00382996 1.00781 4.625 -0.0454102 -0.328125 -0.18457 -0.679688 -1.19531 -1.67188 -0.376953 0.176758 -0.441406 -0.789062 -0.679688 -1.19531 0.414062 2.23438 0.90625 -0.00738525 1.48438 0.029541 0.00756836 -0.0124512 -0.248047 0.921875 -0.0228271 -0.0177002 0.00195312 0.546875 -1.17969 -0.941406 -0.00219727 0.000473022 0.96875 0.000549316 -0.139648 2.14062 -0.478516 -0.902344 -2.67188 0.208984 1.21875 5.57899E-05 -1.14062 0.0147705 -0.00692749 -0.447266 -0.00299072 0.0180664 0.0419922 -1.69531 0.0515137 -0.0197754 -0.175781 -0.3125 1.97656 0.000839233 0.554688 -3.0625 0.726562 0.402344 -0.234375 0.988281 -0.144531 1.22656 -0.503906 -0.0292969 6.6875 0.953125 0.102051 0.339844 0.000137329 -0.0045166 0.0209961 -1.36719 1.75781 -6.84375 1.04688 -0.00262451 -0.304688 0.445312 -2.98438 0.652344 0.0177002 -0.980469 -9.6875 0.0583496 -0.617188 -0.335938 0.188477 -0.0478516 -0.0649414 -2.20312 -0.0192871 -4.78125 -0.498047 9.75 -0.466797 -0.378906 -0.0272217 -0.570312 -0.0283203 -0.105469 0.024292 -0.074707 0.0140991 0.0157471 -0.78125 0.191406 0.0568848 -0.457031 0.178711 0.166992 -0.332031 2.8125 0.0708008 0.0179443 -0.00726318 0.0400391 0.0227051 3.96875 0.00196838 -0.425781 -0.0269775 -0.605469 -0.00109863 -3.98438 -0.0324707 -0.761719 0.392578 0.574219 0.00628662 0.300781 0.0111084 -0.000227928 -0.236328 -0.675781 -0.527344 0.455078 -0.0839844 -0.265625 0.0566406 -2.53125 1.26562 8.375 -1.01562 4.15625 -0.0145874 2.01562 2.65625 -15.9375 -0.414062 -0.041748 1.36719 -0.941406 -0.00958252 0.427734 -1.45312 -0.804688 -0.451172 0.000465393 -0.527344 0.00259399 -1.10938 0.0524902 0.75 2.01562 0.106445 0.174805 0.466797 -4.9375 -1.39844 -1.35156 1.5 0.988281 7.71875 -0.259766 0.015625 -0.933594 -0.679688 0.480469 -0.289062 -0.40625 -3.375 0.625 -0.498047 -0.00124359 -0.1875 1.05469 1.08594 0.328125 -0.298828 21.5 -0.0488281 0.0786133 -0.302734 0.065918 -0.523438 0.045166 -3.5 -0.339844 -0.0751953 0.835938 -0.00233459 -0.298828 0.730469 0.198242 0.0306396 0.910156 -2.95312 0.357422 0.558594 0.0236816 0.503906 0.167969 1.46875 -0.339844 -0.118652 -0.0258789 -0.703125 -0.146484 5.65625 0.0368652 0.695312 -1.23438 0.0192871 0.00549316 0.400391 1.27344 -0.285156 0.123535 -0.882812 -2.76562 -3.96875 -0.392578 5.875 1.65625 -4.81606E-05 -0.0311279 1.65625 0.0358887 -0.00619507 -0.0839844 -0.00372314 -0.605469 -0.875 -0.000774384 -0.139648 -5.1875 1.32031 0.0466309 5.46875 0.546875 -2.71875 0.0142212 -0.00537109 -0.0175781 -0.0098877 0.796875 0.166992 -1.92969 0.211914 -0.511719 -0.00227356 -0.135742 0.120117 0.0303955 -1.13281 0.734375 0.090332 1.35938 0.523438 -0.320312 -0.000549316 4.96875 -0.00184631 -0.328125 -5.31673E-05 -0.0151367 -2.01562 -0.00267029 1.08594 0.00750732 -0.84375 -0.472656 -2.17188 0.177734 -0.328125 -0.00387573 0.00043869 -0.1875 0.0108643 -6.96875 0.273438 11.4375 -0.00320435 -0.0756836 -0.574219 -0.128906 -0.330078 -1.41406 0.789062 -0.000888824 0.0756836 -0.108398 -0.851562 -0.041748 -1.95312 0.298828 -0.351562 1.71875 1.71094 0.185547 -0.0383301 -4.625 0.0517578 0.0172119 -0.474609 -0.0703125 0.116211 -0.785156 0.121582 
0.00195312 0.0219727 -0.0664062 0.00292969 -2.59375 -0.000602722 0.0683594 1.86719 0.345703 -0.00151062 0.0285645 -0.390625 -1.42969 0.4375 -0.363281 -0.953125 -0.271484 -0.000576019 0.0664062 -0.0664062 -0.00396729 0.304688 -1.07812 -0.945312 1.82031 -5.1856E-06 0.714844 -0.217773 -1.58594 2.23438 -2.10938 -0.00958252 -0.679688 0.160156 -2.03125 -0.0483398 -0.640625 -0.0332031 -0.165039 0.90625 0.198242 -0.0742188 -2.29688 -0.925781 6.75 -4.84375 0.0849609 -1.50781 -0.0118408 0.0113525 -1.61719 8.0625 0.0532227 0.165039 0.287109 0.15625 0.855469 0.314453 -1.19531 -0.0246582 -0.0893555 -0.158203 -0.511719 0.359375 -0.486328 -0.0133057 0.425781 2.39062 -0.283203 0.0025177 0.00964355 -0.285156 0.0917969 0.241211 -0.176758 -0.175781 -0.445312 -0.0142822 -0.392578 0.045166 0.640625 0.00769043 -0.245117 0.09375 0.0218506 -1.53125 -0.123535 -0.386719 -0.00122833 0.0324707 0.0583496 -0.0339355 -0.527344 -1.00781 1.375 -0.106934 -0.15332 0.00204468 -4.90625 0.416016 0.980469 -1.34375 -0.605469 -0.0280762 0.478516 -0.118652 0.208008 2.5 -0.601562 0.249023 0.000190735 0.00601196 0.337891 3.48438 1.17969 0.972656 1.00781 -0.0476074 -0.158203 -0.871094 -0.00497437 -0.482422 -0.122559 0.460938 1.20312 0.519531 0.792969 0.165039 -3.89062 11 -2.92188 0.21582 -0.171875 -1.70312 0.000276566 0.00909424 1.11719 0.172852 -0.515625 -0.283203 -0.0439453 0.00891113 -0.601562 -0.00171661 0.28125 -1.10156 -0.00939941 -0.988281 0.0197754 -1.09375 1.92188 -0.628906 0.150391 9.39369E-05 -0.439453 -0.102539 -0.00616455 -4.64916E-05 -0.0256348 0.00970459 -0.00396729 -0.253906 1.3125 0.233398 -7.5 6.34375 -0.494141 -0.135742 0.0319824 0.101562 -0.125 0.515625 0.132812 0.519531 -1.86719 -5.03125 -0.154297 0.000375748 -0.0703125 -0.203125 2.23438 -0.00190735 -0.0932617 0.00134277 -0.0373535 0.0124512 0.0198975 1.17969 0.318359 0.0649414 0.00476074 -0.984375 -0.652344 0.457031 0.0175781 0.0327148 0.730469 0.000537872 -0.800781 -0.318359 -0.013916 -1.17969 -0.365234 -4.59375 0.601562 -0.00842285 -1.60938 -1.94531 -1.14062 -0.195312 0.339844 2.4375 0.546875 0.00585938 0.0192871 -0.71875 0.496094 -0.235352 -0.558594 1.50781 0.00476074 5.125 -0.175781 0.131836 1.09375 -0.542969 -0.988281 -0.0192871 0.00610352 0.652344 -1.45312 -0.03125 0.0407715 0.400391 -0.0118408 -1.47656 -0.00173187 0.882812 -0.585938 -1.125 0.233398 0.0375977 1.57812 -2.85938 0.000324249 0.00121307 0.0446777 1 0.0461426 0.0410156 -0.0673828 1.95312 0.103027 0.0301514 -9.375 -0.00671387 0.00662231 0.00726318 -0.0240479 0.00209045 -0.910156 1.98438 -2.0625 0.036377 -2.10938 -2.625 -0.996094 -0.219727 -2.64062 -0.345703 -0.00159454 -0.0032196 -0.00680542 6.375 -0.875 0.174805 2.17188 0.167969 0.0874023 2.23438 -0.373047 0.000149727 -0.292969 -0.0415039 -0.632812 -0.289062 -0.00363159...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]

faulty_inp1.mlir

iree-compile faulty_inp1.mlir  --iree-hip-target=gfx942  --iree-hal-target-device=hip   --iree-dispatch-creation-enable-aggressive-fusion=true   --iree-global-opt-propagate-transposes=true   --iree-opt-aggressively-propagate-transposes=true   --iree-opt-data-tiling=false   --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'   --iree-hal-indirect-command-buffers=true   --iree-stream-resource-memory-model=discrete   --iree-hal-memoization=true   --iree-opt-strip-assertions -o=init1.vmfb

iree-run-module --hip_use_streams=true --device=hip://0 --function=faulty --module=init1.vmfb [email protected] [email protected]
EXEC @faulty
result[0]: hal.buffer_view
4096x4096xbf16=[1.5 -9 -2 -5.5 -1 8 -0.21875 -0.6875 -7.5 3.5 -3.75 0.4375 -1 -5 2.75 1.75 -5.5 -0.46875 -4.5 3.75 -2.25 1.875 -3.5 9 0.9375 -0.8125 -0.75 8 6 8 0.34375 -3.25 7 1.25 -6 4 3.75 1.75 9 -3.5 0.171875 0.5 3.5 -6.5 -2.75 7.5 -2.5 -1 -4 -14 -1.375 0.9375 11 0.9375 0.75 3.5 -0.00488281 -0.25 -1 5.5 -0.875 -0.6875 0.875 1.75 2.75 -0.34375 4 -4.5 -0.625 -0.5625 0.234375 -0.375 -0.0546875 14 -2 -4 0.5625 -1.25 1.25 0.46875 0.28125 6 5.5 -1.125 -0.34375 9 -0.625 -2.5 -0.5625 0.140625 5 3.25 1.25 -1.25 -0.34375 0.9375 12 -3.5 5 1.375 1 -1.375 -7 7 -0.40625 -0.3125 -1 0.15625 -0.21875 -4.5 -0.40625 -7 5.5 -1.875 5 -0.5 -3.75 1.125 -1.5 -0.5625 -1 8 0.21875 -1.25 3 0.625 -6 -5.5 -1.25 11 -0.101562 0.015625 -2.75 5 -0.25 2.25 -10 7.5 0.8125 -12 0.28125 1.125 1.5 -1 -1.875 -7 -0.46875 -2.75 -1 -0.25 -0.0351562 5.5 0.34375 -3 -3.25 -0.6875 -1.25 5.5 2.75 -3.75 -5.5 -8 -7 -4.5 -0.1875 -3.25 0.101562 1.5 -1.75 0.234375 10 -0.5625 -0.75 -0.5 -3.25 1.625 -0.625 -0.0292969 -0.1875 -2.75 -1.375 -0.6875 6 -4 0.0585938 -6.5 2.5 -0.21875 3.75 -1 -3.25 -0.46875 0.203125 -1.5 10 0.078125 -13 -0.6875 9 5 1.375 5 -0.75 9 -4 -0.109375 0.46875 7 -2.25 -10 1.25 -14 -2.5 6 -4 -8 -5.5 -1.5 0.4375 2.25 4 -1.375 -0.109375 0.15625 -0.34375 1.875 -6.5 6 -1.625 2.25 -0.375 -0.5625 9 1.25 8 0.34375 -0.8125 2 3.25 0.15625 0.28125 -0.6875 -3.5 -2.25 2 -0.5625 -3 0.5625 7 -0.046875 -3.5 -4 1.5 9 -7 -8 -5 5 -0.5625 -4 3.5 0.6875 3.25 2.75 9 -0.375 7.5 -0.6875 12 0.46875 -0.625 0.1875 12 0.203125 -0.21875 0.0703125 -12 0.5625 -5.5 -0.234375 -1.875 5.5 1.375 5.5 -1.375 -0.0703125 0.078125 -0.140625 0.5 0.15625 0.5625 3.5 0.375 -5 -0.625 -0.6875 5 -3.5 2.5 7.5 0.5625 -2.25 0.875 5.5 0.171875 -4.5 5 -0.0546875 0.75 6.5 -3.5 -3 -7.5 1.125 -11 -11 -0.25 0.9375 -4.5 1.25 -8 -3.75 2 0.3125 -2.5 -4 -2.5 -2 2 7.5 9 2 -0.9375 0.5 -11 -8 0.5625 0.6875 0.125 2.5 9 -0.8125 0.5625 -9 -0.28125 -2 -2 6 4 -7 2.5 0.40625 2.75 1 0.117188 1.625 0.0859375 3 -0.875 -0.6875 -0.234375 -0.9375 -2 2.25 3.75 -0.40625 -1.375 -7 0.3125 1.25 1.875 -0.625 -0.1875 0.203125 0.9375 -4.5 -0.5625 12 -2 -0.234375 1 3.25 -0.3125 -5 -3.75 0.9375 2.25 -0.875 0.75 1.5 -3.25 0.6875 6.5 -1.75 -1.375 2 0.28125 3.25 -0.0234375 3.75 1.25 8 -0.5 1.625 -0.8125 -0.8125 0.101562 0.375 1.875 9 0.75 -7 -7.5 -0.00878906 -8 -0.15625 -4.5 0.9375 0.40625 2 0.0078125 8 5 2.75 3 -9 0.125 -0.203125 0.625 6 -3.75 1.875 8 -0.375 -3 0.109375 -0.75 -4 -0.4375 5 0.0625 -0.015625 -0.4375 0.75 0.8125 -1 -0.101562 -0.171875 1.75 0.109375 -12 -5 9 0.109375 1.75 5.5 3.75 0.5625 -0.625 -0.078125 -0.0625 -3 -0.375 0.101562 1.25 7 -0.0703125 0.625 0.6875 -0.75 -9 0.5 2.75 3 -5.5 -2.5 -1 7.5 5.5 -2.5 2 0.4375 -4.5 7 9 15 0.34375 -0.625 0.875 5.5 6 3.25 2.25 -0.25 -0.40625 -0.6875 3.25 3.75 0.6875 8 1.5 0.5625 4 -1.625 -0.0214844 -3.25 -12 -3.75 0.00683594 7 3.5 -0.125 6 -3.5 -0.5625 -7 1 -0.00390625 -0.875 -0.203125 -0.101562 3.5 -0.6875 0.375 -0.6875 -6 2.5 2 0.3125 -6 -6 0.0703125 2.25 6.5 -2 -5.5 -0.0195312 0.625 2.25 5.5 0.171875 5.5 -0.3125 0.75 -5.5 -7.5 -0.5 4.5 1.375 -0.203125 -9 2 -1 -0.625 -0.4375 0.75 0.09375 -0.75 0.625 -5.5 -8 2.75 0.6875 -7 -3.75 -16 9 -1 -7 -0.140625 -1.5 -3.5 -3.25 -2 7 -0.078125 2.25 -0.3125 1.625 -0.1875 -3 -2.25 5.5 2.75 -0.117188 -2.75 0.21875 -1.75 -4.5 0.28125 0.6875 -6.5 -0.03125 2 -3.5 7 1.75 8 0.140625 -4 4.5 -0.46875 7 -3.25 3.25 -7 -6.5 -11 -6 -1.125 -0.03125 2.25 -7.5 -0.8125 2 -8 -1 -5.5 -0.0351562 3.5 -6 0.203125 2.75 -3.25 -1 7 -8 0.34375 0.5 1 0.0703125 -7.5 3 0.25 9 1 -0.46875 -4.5 -3.75 -4 -1 0.125 5.5 0.9375 -0.5625 -0.9375 -1.25 -5.5 
-0.75 -0.203125 1.375 2.25 0.4375 -1.625 0.8125 2.75 0.101562 10 -2.25 -7.5 10 0.75 4.5 -5 1.5 -0.0703125 -3.25 -3.75 10 1.5 -0.5 -1 0.875 8 2.25 -5.5 0.15625 -10 -0.375 -8 -1.5 -0.8125 3.5 -2.25 2.5 1.25 -1.75 0.875 0.203125 0.75 0.171875 -1.25 1 -2 3.25 -1.25 0.28125 -1.125 -6 0.5 2.25 -0.9375 11 0.875 0.28125 -0.8125 -1.25 -5.5 -1.625 -5 0.5625 -1.375 -7.5 -0.5 4.5 -8 5 1.375 -0.75 -1.5 -11 -4.5 0.171875 -2.5 0.5625 9 -4 3 1.375 0.171875 -0.75 -0.375 0.15625 -6 2 1.875 -1.375 -0.28125 -7 -0.5625 -0.8125 -3.25 -0.3125 1.25 -1.625 -4 0.6875 1 2.25 12 -1.125 3.75 -6.5 0.40625 1 0.5625 2 -0.5 -7.5 7 -2.25 2.5 -0.4375 11 -3.75 -1.75 -2.75 -5.5 0.5 0.75 -5.5 -7.5 0.09375 -7.5 -0.0253906 -7.5 4.5 -4 0.375 -0.125 -0.15625 0.75 -3 0.25 -6 0.375 -1 5 0.171875 5.5 5 -0.875 1.375 0.1875 -2 2 0.0292969 0.9375 -2 1.375 5 4 -0.0078125 11 0.5625 0.5 3.75 0.75 -3.75 -2.25 -1.25 -1 1.625 6.5 2.75 -0.875 -0.5625 -2.25 0.5625 5 -1.5 0.875 0.0136719 2.25 6.5 0.625 0.3125 -1.75 -0.078125 -0.34375 2 -2.25 -0.5 2.5 0.28125 4.5 4 -3 2 1.5 -0.3125 -6 2 -0.4375 5 7.5 -0.4375 0.5625 -1.5 1.25 -0.15625 -0.46875 1 6 -1.625 -10 2 -0.101562 1.5 -0.125 -3.75 -2 5 -6 -2.75 -1.375 7 -6.5 6.5 5 4.5 -0.8125 -0.125 1.125 -4 -0.28125 -3.75 -0.0117188 8 -5 15 -5 6 -1.875 2.75 -1.5 2 -8 -9 3.5 0.117188 10 -8 -4.5 0.3125 -0.203125 -0.34375 -0.09375 0.0126953 -9 -0.625 -5.5 -2.75 -0.6875 0.21875 0.0214844 -2.5 -0.875 -4 0.4375 6 3.5 -3.25 2.25 -0.625 -2 -0.5 10 -0.140625 -1.375 0.171875 -0.25 0.3125 0.5 -4.5 -3.25 0.75 0.75 1.75 -7 0.9375 1.5 -0.28125 0.5625 -0.625 -2.75 -0.40625 -7.5 -4.5 5.5 -0.625 -0.875 0.46875 3.75 -1.875 1.5 0.000976562 0.6875 -7 -9 -0.25 -5 4 -0.0351562 1.625 -0.6875 0.28125 -0.46875 1.375 3.5 -2 0.4375 5 -1.25 0.0507812 1.375 3 4.5 -7 -1.375 -1.125 0.0625 -11 -0.4375 -0.46875 -5 9 2 -1.875 -8 7.5 0.25 -0.9375 -0.0859375 -3.5 0.4375 -1.75 -0.0195312 -0.5625 -4.5 -0.09375 0.375 -0.75 -0.3125 -4 0.15625 -1 2 -2.25 4.5 1.25 2 -0.3125 1.375 -8 -1.25 -4 1 -0.5625 1.25 3.25 3.5 3.75 -2.75 2.5 6.5 -1.25 0.875 -0.25 -0.25 -2 1.125 -1.625 -1.625...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][......][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]

faulty.mlir

iree-compile faulty.mlir  --iree-hip-target=gfx942  --iree-hal-target-device=hip   --iree-dispatch-creation-enable-aggressive-fusion=true   --iree-global-opt-propagate-transposes=true   --iree-opt-aggressively-propagate-transposes=true   --iree-opt-data-tiling=false   --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'   --iree-hal-indirect-command-buffers=true   --iree-stream-resource-memory-model=discrete   --iree-hal-memoization=true   --iree-opt-strip-assertions -o=init.vmfb

iree-run-module --hip_use_streams=true --device=hip://0 --function=faulty --module=init.vmfb [email protected] [email protected]
EXEC @faulty
result[0]: hal.buffer_view
32x4096xbf16=[-112 240 NAN NAN NAN 128 -15 -144 NAN NAN NAN -112 -160 -224 NAN 72 NAN 26 NAN NAN NAN 40 -32 NAN -160 6 88 NAN NAN NAN NAN NAN NAN NAN NAN NAN 72 -208 NAN NAN -28 26 NAN NAN NAN NAN -72 120 NAN -88 NAN -12 -72 NAN NAN -26 NAN -144 -128 NAN -18 13 NAN -56 NAN 128 -208 NAN -28 -64 13 16 -4.5 -176 NAN NAN 4.5 NAN 36 NAN -30 -72 NAN 28 -40 -192 8 NAN NAN NAN NAN NAN NAN 22 -176 NAN NAN NAN NAN -44 NAN 22 160 -208 NAN 64 44 NAN -104 NAN 60 NAN NAN NAN NAN NAN NAN NAN NAN NAN NAN NAN NAN NAN NAN -52 NAN 104 NAN -240 NAN -36 NAN NAN -104 NAN -104 NAN -64 NAN NAN 64 -208 -10 NAN NAN 208 NAN 88 48 1 NAN -224 NAN 56 -48 -104 -160 NAN -11 -240 NAN NAN -88 NAN -192 NAN NAN NAN -11 NAN -7 -80 NAN NAN 22 -14 56 -192 NAN NAN 160 NAN NAN 120 NAN NAN -44 NAN NAN NAN 36 -22 NAN 22 144 -36 -72 NAN NAN 20 NAN -26 NAN NAN NAN 10 NAN NAN NAN NAN NAN NAN NAN NAN -56 NAN NAN -2.25 -6.5 NAN 52 144 NAN -88 NAN 208 NAN NAN NAN -30 88 NAN -144 NAN -30 128 NAN NAN NAN -14 -7 -40 NAN -64 -80 NAN 240 176 13 NAN NAN 104 NAN NAN NAN NAN NAN 176 NAN NAN NAN NAN NAN NAN NAN NAN 4 NAN NAN -56 144 NAN -20 -32 NAN -160 24 NAN -9 0.40625 240 -208 NAN NAN -30 -5 -64 -40 -72 NAN NAN NAN NAN 36 -4 NAN NAN NAN NAN 6.5 NAN 60 NAN -48 240 NAN 104 7.5 NAN -3.25 -96 NAN -112 NAN NAN -12 -104 -144 48 NAN NAN NAN NAN 80 NAN 144 NAN 240 NAN NAN NAN 26 -22 NAN NAN 176 -52 -32 NAN NAN -60 32 NAN 10 NAN NAN NAN NAN NAN NAN -30 NAN 1.25 -40 NAN NAN NAN 104 -128 -176 104 -72 NAN -88 192 -26 64 NAN NAN NAN NAN NAN 11 NAN NAN 44 NAN 96 18 -20 NAN 36 NAN NAN 32 2.5 NAN NAN NAN NAN -22 -80 NAN NAN NAN -52 NAN -2 192 -40 208 40 NAN NAN NAN 128 NAN NAN NAN NAN NAN NAN 36 56 18 NAN NAN -48 NAN 160 NAN NAN NAN NAN NAN NAN NAN -240 NAN 192 NAN NAN 16 NAN NAN -20 NAN 224 -60 -18 60 -22 -224 NAN NAN -5.5 -6.5 -88 44 224 NAN NAN NAN NAN NAN NAN -128 NAN -30 22 NAN 7 52 -36 NAN NAN 44 NAN 160 NAN -20 -192 NAN NAN 240 NAN NAN 15 NAN NAN -64 NAN NAN NAN NAN -1 -16 -16 NAN 112 NAN -160 52 NAN NAN NAN NAN -56 NAN 240 22 NAN NAN 28 NAN NAN NAN -26 -240 NAN NAN NAN NAN 208 NAN -104 192 56 NAN 120 NAN NAN -72 -7 NAN NAN NAN NAN NAN NAN 22 NAN NAN 96 NAN -2 72 -72 NAN -14 NAN -16 NAN -30 NAN 2 NAN -10 0.34375 NAN NAN NAN NAN 192 104 -26 NAN NAN NAN NAN NAN -26 NAN NAN NAN NAN 44 NAN NAN 26 NAN 192 NAN NAN -2.25 NAN -32 192 104 NAN NAN NAN NAN -160 NAN NAN -240 NAN NAN NAN NAN 18 104 NAN NAN NAN NAN NAN NAN NAN 12 -160 NAN NAN 128 NAN NAN NAN 20 NAN 120 NAN -80 NAN NAN -64 32 -12 NAN NAN 40 NAN NAN NAN NAN NAN 16 -128 -72 208 144 NAN -88 240 NAN NAN -224 NAN NAN 72 5 NAN 48 16 192 -192 NAN NAN NAN NAN 128 -7.5 NAN NAN NAN 56 NAN 60 NAN NAN -52 -120 NAN NAN -60 NAN -240 NAN 128 88 -4 18 NAN -18 NAN -224 NAN -14 2 -128 -0.15625 -160 -208 NAN NAN NAN 120 -1.25 NAN -12 -96 5 36 NAN 128 NAN -36 NAN NAN 36 48 NAN -4.5 44 -72 30 NAN NAN NAN 36 NAN -192 NAN NAN NAN NAN 20 NAN 176 112 NAN -88 -9 NAN NAN NAN NAN 48 NAN -72 26 NAN -144 40 NAN 52 -18 NAN -128 -32 NAN -44 -48 NAN NAN -22 -44 -208 NAN NAN NAN NAN -128 0.46875 32 -4.5 30 NAN -208 NAN NAN 32 NAN NAN NAN NAN NAN -44 -224 NAN NAN -20 NAN -5 NAN NAN -60 36 104 NAN NAN NAN 48 -72 -18 0.9375 NAN -128 NAN NAN NAN 44 NAN -120 NAN 5.5 -160 -88 NAN -6.5 NAN -18 NAN -80 NAN 240 48 NAN -88 NAN NAN NAN 192 NAN -104 NAN 44 -16 14 56 11 NAN -48 -72 -20 -6.5 NAN 224 NAN NAN NAN 176 NAN -44 NAN -120 NAN NAN NAN -72 NAN -16 NAN NAN NAN NAN -2.75 -18 -22 NAN NAN NAN NAN NAN NAN 128 -10 NAN 80 NAN NAN NAN NAN -32 NAN NAN NAN -30 NAN NAN 32 -9 NAN NAN NAN 2.75 -120 NAN NAN NAN 20 NAN 24 NAN -80 NAN NAN -160 NAN -26 NAN 
NAN -104 -48 1.375 192 -5.5 NAN NAN 160 NAN NAN NAN 88 -0.40625 -176 -160 NAN 40 -22 NAN NAN NAN -40 -20 NAN NAN 18 NAN 18 -208 -14 -32 NAN NAN -10 60 NAN NAN NAN 36 NAN NAN 32 NAN NAN 176 -176 NAN NAN NAN -40 NAN NAN NAN 160 NAN NAN -26 -5.5 -96 NAN NAN NAN NAN NAN 20 NAN NAN 15 NAN NAN NAN -48 -0.9375 NAN NAN NAN -112 NAN -24 NAN -0.875 NAN NAN NAN 240 8 80 120 -8 8 52 NAN -44 -88 -104 NAN NAN 40 NAN -40 -4 -10 -60 18 NAN NAN NAN -18 NAN NAN NAN NAN NAN NAN 40 -7 -112 NAN -36 NAN NAN -20 NAN NAN -64 -28 NAN 64 NAN NAN NAN...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]
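
For context, a hypothetical reconstruction of what the removed matmul path in faulty.mlir presumably looks like, inferred from the dtype constants left in the trimmed variants above (26 = f8E4M3FNUZ, 15 = bf16) and the torch IR pattern quoted later in this thread; the value names here are made up:

    func.func @faulty(%arg0: !torch.vtensor<[32,4096],bf16>, %arg1: !torch.vtensor<[4096,4096],bf16>) -> !torch.vtensor<[32,4096],bf16> {
      %int26 = torch.constant.int 26  // 26 == torch.float8_e4m3fnuz
      %int15 = torch.constant.int 15  // 15 == torch.bfloat16
      // Cast both operands down to f8, matmul in f8, then cast the f8 result back up to bf16.
      %lhs_f8 = torch.prims.convert_element_type %arg0, %int26 : !torch.vtensor<[32,4096],bf16>, !torch.int -> !torch.vtensor<[32,4096],f8E4M3FNUZ>
      %rhs_f8 = torch.prims.convert_element_type %arg1, %int26 : !torch.vtensor<[4096,4096],bf16>, !torch.int -> !torch.vtensor<[4096,4096],f8E4M3FNUZ>
      %mm = torch.aten.mm %lhs_f8, %rhs_f8 : !torch.vtensor<[32,4096],f8E4M3FNUZ>, !torch.vtensor<[4096,4096],f8E4M3FNUZ> -> !torch.vtensor<[32,4096],f8E4M3FNUZ>
      %out = torch.prims.convert_element_type %mm, %int15 : !torch.vtensor<[32,4096],f8E4M3FNUZ>, !torch.int -> !torch.vtensor<[32,4096],bf16>
      return %out : !torch.vtensor<[32,4096],bf16>
    }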

@MaheshRavishankar
Contributor

@AmosLewis can you attach the dump from adding --mlir-print-ir-after-all --mlir-print-ir-before-all --mlir-disable-threading --mlir-print-local-scope ?

@MaheshRavishankar
Contributor

BTW, the most relevant people here are in the UK. I asked @nirvedhmeshram to take a look, but this will likely not get many eyes on it until Monday.

@pashu123
Contributor

pashu123 commented Jan 31, 2025

@AmosLewis can you attach the dump from adding --mlir-print-ir-after-all --mlir-print-ir-before-all --mlir-disable-threading --mlir-print-local-scope ?

Here's the dump: https://gist.github.com/pashu123/07f94ba18756b36891828b88a56f9a55

@nirvedhmeshram
Contributor

nirvedhmeshram commented Jan 31, 2025

Here is a little experiment I did to convince myself that this is an overflow issue.

In the matmul, the truncation seems to be unnecessarily going to f8 and then back to bf16, which I believe causes the overflow:

        %6 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<32x4096xf32>) outs(%2 : tensor<32x4096xbf16>) {
        ^bb0(%in: f32, %out: bf16):
          %7 = arith.truncf %in : f32 to f8E4M3FNUZ
          %8 = arith.extf %7 : f8E4M3FNUZ to bf16
          linalg.yield %8 : bf16
        } -> tensor<32x4096xbf16>
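
In f8E4M3FNUZ the largest finite magnitude is 240 and, with only three mantissa bits, every representable value of magnitude 8 or more is an integer, which is consistent with the integer-looking entries next to the NANs in the dumps above. For comparison, a minimal hedged sketch of the same truncation without the f8 round-trip (map and value names reused from the snippet above purely for illustration), so that values above 240 stay finite in bf16:

        %6 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<32x4096xf32>) outs(%2 : tensor<32x4096xbf16>) {
        ^bb0(%in: f32, %out: bf16):
          // Truncate f32 directly to bf16; the value never passes through the f8E4M3FNUZ range.
          %7 = arith.truncf %in : f32 to bf16
          linalg.yield %7 : bf16
        } -> tensor<32x4096xbf16>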

@AmosLewis
Contributor Author

      %7 = arith.truncf %in : f32 to f8E4M3FNUZ
      %8 = arith.extf %7 : f8E4M3FNUZ to bf16

@nirvedhmeshram I just looked at the dump; this pattern first appears after op fusion (ElementwiseOpFusionPass). Any idea how to work around it? Should it be fixed by adding a pattern in IREE that folds it into %8 = arith.truncf %in : f32 to bf16?

@nirvedhmeshram
Contributor

      %7 = arith.truncf %in : f32 to f8E4M3FNUZ
      %8 = arith.extf %7 : f8E4M3FNUZ to bf16

@nirvedhmeshram I just looked at the dump; this pattern first appears after op fusion (ElementwiseOpFusionPass). Any idea how to work around it? Should it be fixed by adding a pattern in IREE that folds it into %8 = arith.truncf %in : f32 to bf16?

I asked @MaheshRavishankar about it, and he said we can't do anything in the compiler in a foolproof way. We need to understand why the model is doing an f8 matmul and then always casting the result back to bf16, and whether something can be done at the torch level, as we see this everywhere in the model:

    %8679 = torch.aten.mm %8678, %8676 : !torch.vtensor<[1,14336],f8E4M3FNUZ>, !torch.vtensor<[14336,4096],f8E4M3FNUZ> -> !torch.vtensor<[1,4096],f8E4M3FNUZ>
    %int1_11670 = torch.constant.int 1
    %int1_11671 = torch.constant.int 1
    %int4096_11672 = torch.constant.int 4096
    %8680 = torch.prim.ListConstruct %int1_11670, %int1_11671, %int4096_11672 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %8681 = torch.aten.view %8679, %8680 : !torch.vtensor<[1,4096],f8E4M3FNUZ>, !torch.list<int> -> !torch.vtensor<[1,1,4096],f8E4M3FNUZ>
    %int15_11673 = torch.constant.int 15
    %8682 = torch.prims.convert_element_type %8681, %int15_11673 : !torch.vtensor<[1,1,4096],f8E4M3FNUZ>, !torch.int -> !torch.vtensor<[1,1,4096],bf16>

@MaheshRavishankar
Contributor

If the value goes out of range after the truncate, there is really nothing that the compiler can do here. In effect it is maintaining program semantics, because the semantics of the program are to go out of range. I think the issue is with the quantization: the weights haven't been quantized correctly to keep the result within the range of f8 values.
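
For concreteness, a tiny hedged illustration (constant chosen arbitrarily): 240 is the largest finite f8E4M3FNUZ value, and the format has no infinity, so an overflowing truncation comes back as NAN once it is extended to bf16, which matches the outputs observed above.

    func.func @overflow_demo() -> bf16 {
      // 300.0 exceeds the f8E4M3FNUZ maximum magnitude of 240.
      %big = arith.constant 3.000000e+02 : f32
      // Overflow in a format without infinity: the truncation yields NaN ...
      %f8 = arith.truncf %big : f32 to f8E4M3FNUZ
      // ... and the NaN survives the extension back to bf16.
      %res = arith.extf %f8 : f8E4M3FNUZ to bf16
      return %res : bf16
    }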

@IanNod
Contributor

IanNod commented Jan 31, 2025

I believe this model does attention in bf16, based on how the quantization team quantized it, so it is likely casting up to that dtype and back for each SDPA op. @dan-garvey would know the specifics of how/why that is being done.

@dan-garvey
Contributor

The model is quantized to work in the following way:

Do the matmuls in fp8. The intrinsic for fp8 matmul on MI300 outputs into fp32, which is then truncated to bf16.
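
At the linalg level, a hedged sketch of that intended flow might look like the following (shapes are borrowed from the faulty_inp* snippets above purely for illustration; this is not IR taken from the model): f8 operands, an f32 accumulator, and a single truncation straight to bf16, with no intermediate f8 round-trip.

    #map = affine_map<(d0, d1) -> (d0, d1)>
    func.func @fp8_matmul_bf16(%lhs: tensor<32x4096xf8E4M3FNUZ>, %rhs: tensor<4096x4096xf8E4M3FNUZ>) -> tensor<32x4096xbf16> {
      %zero = arith.constant 0.000000e+00 : f32
      %acc_init = tensor.empty() : tensor<32x4096xf32>
      %acc = linalg.fill ins(%zero : f32) outs(%acc_init : tensor<32x4096xf32>) -> tensor<32x4096xf32>
      // f8 operands accumulated in f32, mirroring the MI300 fp8 matmul intrinsic.
      %mm = linalg.matmul ins(%lhs, %rhs : tensor<32x4096xf8E4M3FNUZ>, tensor<4096x4096xf8E4M3FNUZ>) outs(%acc : tensor<32x4096xf32>) -> tensor<32x4096xf32>
      %out_init = tensor.empty() : tensor<32x4096xbf16>
      // Truncate the f32 result once, directly to bf16.
      %out = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%mm : tensor<32x4096xf32>) outs(%out_init : tensor<32x4096xbf16>) {
      ^bb0(%in: f32, %o: bf16):
        %t = arith.truncf %in : f32 to bf16
        linalg.yield %t : bf16
      } -> tensor<32x4096xbf16>
      return %out : tensor<32x4096xbf16>
    }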

@nirvedhmeshram
Contributor

nirvedhmeshram commented Feb 1, 2025

I see; that's not what the torch IR snippet I showed above will do, since the aten.mm has -> f8 in it. It is first truncating to f8 and then extending to bf16. I think we want aten.mm ops with (f8, f8) -> bf16; then it will do the right thing.
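
A hedged sketch of that suggested shape, reusing the value names from the torch IR quoted above (whether torch-mlir accepts this mixed-dtype aten.mm form directly is an assumption, and it is presumably what the custom kernel discussed below would provide):

    // Hypothetical: the matmul consumes f8 operands but produces bf16 directly,
    // so no f8-typed result value exists to overflow.
    %8679 = torch.aten.mm %8678, %8676 : !torch.vtensor<[1,14336],f8E4M3FNUZ>, !torch.vtensor<[14336,4096],f8E4M3FNUZ> -> !torch.vtensor<[1,4096],bf16>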

@dan-garvey
Contributor

OK, so we need a custom kernel that does that. I think they did the same thing for punet. I'll try to get new IR going over the weekend.

@dan-garvey
Contributor

1x32x128256xbf16=[[-0.490234 -0.234375 -1.35156 2.34375 1.76562 -0.00878906 0.0220947 2.3125 0.785156 1.74219 -0.929688 -0.335938 -2.76562 1.44531 1.0625 -1.38281 -0.25 1.79688 1.65625 1.13281 2.28125 1.71875 -0.769531 1.09375 0.75 1.125 -1.47656 1.65625 0.0820312 0.507812 -0.152344 0.957031 2.125 1.14062 2.14062 3.78125 2.20312 2.48438 0.0966797 0.298828 1.42969 1.75781 1.91406 1.66406 3.26562 0.229492 2.73438 0.136719 3.3125 0.878906 1.59375 0.34375 2 0.359375 1.44531 2.29688 4.09375 2.42188 0.613281 -1.07812 -0.839844 2.28125 -0.296875 0.933594 2.57812 3.51562 2.20312 1.45312 1.01562 1.14844 -0.785156 -2.26562 0.753906 2.85938 2.125 1.28906 1.82031 2.15625 -1.21094 -1.51562 0.273438 -3.71875 -0.609375 2.1875 -1.82812 2.96875 4.03125 -0.578125 0.933594 1.60156 -3.20312 -0.597656 0.675781 -0.539062 1.38281 0.277344 5.125 -0.988281 3.17188 -2.32812 -0.289062 -3.96875 -1.29688 -0.105469 -0.515625 0.664062 2.28125 1.55469 -1.48438 -1.22656 -3.875 -7.71875 -0.151367 -1.21094 -0.0668945 0.742188 0.648438 1.14844 2.14062 -2.23438 2.375 -3.85938 2.96875 2.39062 -0.498047 -0.498047 -2.84375 2.95312 0.640625 4.5 -1.25781 1.49219 2.60938 -0.800781 1.07031 -3.82812 -0.523438 6.75 1.84375 -0.863281 -3.0625 -3 -0.617188 0.972656 -1.96875 1.35938 0.029541 -0.25 -1.10938 -4.46875 -0.910156 -1.49219 -0.255859 -1.96875 -2.3125 0.621094 1.90625 2.375 1.96875 -0.412109 -1.46875 0.12793 4.09375 6.71875 1.625 -3.10938 -1.33594 -1.85938 -3.3125 3.98438 -3.65625 0.306641 -1.46875 -1.35938 -0.894531 2.5625 -3.9375 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -0.498047 -2.96875 0.490234 -1.10938 -0.412109 -0.277344 0.648438 -1.32812 0.0488281 -1.26562 -0.710938 -1.60156 -0.308594 -1.78125 0.474609 2.1875 -0.183594 2.26562 -0.722656 -1.53125 -1.125 1.25781 -0.154297 -1.875 -0.628906 -0.382812 -3.10938 -1.67969 -0.855469 -1.58594 -4.03125 -1.75781 -2.48438 0.523438 -2.20312 -3.26562 1.66406 3.5625 -0.0344238 -1.8125 1.49219 0.296875 4.03125 2.85938 0.474609 -0.0981445 1.07812 -1.08594 0.519531 0.941406 5.4375 4.78125 2 -0.351562 3.09375 2.17188 1.78906 -1.45312 1.25 -0.78125 2.46875 1.42188 1.09375 5.34375 -2.76562 2.76562 -0.337891 0.554688 -1.76562 -0.425781 0.652344 1.5 0.154297 1.94531 1.90625 2.26562 0.660156 0.605469 2.57812 3.39062 0.796875 1.05469 1.85938 -1.19531 -1.0625 0.168945 -1.26562 -4.125 0.169922 3.17188 2.4375 2.875 2.84375 -0.832031 -3.17188 -0.223633 -0.957031 -1.75781 2.92188 0.882812 3.39062 4.875 4.15625 -2.1875 1.6875 0.185547 0.652344 -1.58594 -0.828125 1.23438 -2.07812 -0.53125 -0.449219 2.82812 -0.28125 1.75 1.69531 0.0266113 -3.76562 -0.644531 -2.6875 -3.01562 1.72656 2.21875 -1.19531 -1.82812 1.25 -4.625 -0.447266 -1.04688 1.36719 2.15625 0.6875 -0.609375 4.21875 0.890625 0.174805 0.621094 -1.46094 -2.85938 2.28125 1.29688 0.652344 -0.0883789 5.21875 -1.75781 3.89062 3 -0.269531 1.22656 -0.357422 3.25 4.90625 1.80469 -2.84375 -1.79688 0.380859 1.35156 -1.34375 2.29688 -0.890625 -3.84375 2.14062 -0.558594 -0.357422 3.84375 -0.227539 -1.625 2.4375 2.53125 -3.07812 -1.65625 6.71875 3.71875 -0.232422 1.13281 -0.753906 -0.949219 0.345703 0.357422 2 1.66406 0.0717773 5.0625 0.742188 -0.640625 -1.57812 -1.85156 0.357422 1.28125 -0.878906 2.125 1.25 -3.70312 0.585938 -0.498047 0.75 -1.25781 1.41406 0.683594 0.554688 -0.589844 -2.10938 -1.59375 1.36719 -1.28906 -1.45312 -0.769531 -2.10938 1.54688 -0.753906 1.39844 -1.65625 -0.308594 -1.63281 3.28125 1.09375 0.570312 -1.77344 -1.76562 0.349609 1.92969 -3.75 -2.45312 -1.27344 1.5625 -0.824219 
-0.867188 0.660156 5.6875 -0.414062 2.96875 -0.333984 0.890625 0.277344 1.75 3.6875 -0.192383 1.86719 0.960938 1.82812 0.570312 -2.14062 1.97656 1.45312 1.65625 -0.796875 0.636719 2.46875 -2.35938 1.45312 1.91406 1.82031 -1.33594 -5.84375 -0.601562 0.527344 -2.32812 1.26562 1.25781 1.5625 -2.9375 -0.648438 0.119629 5.90625 -0.988281 -3.6875 -1.58594 4.5625 -0.351562 0.980469 -2.6875 0.789062 -0.917969 -1.08594 -0.925781 0.933594 -0.048584 1.74219 1.97656 0.4375 0.0517578 2.32812 -2.64062 -0.160156 0.519531 -1.77344 -4.0625 0.236328 2.28125 -0.535156 1.875 3.17188 -0.351562 3.67188 -3.4375 -2.625 -2.5625 1.86719 0.431641 3.23438 -1.92188 -2.34375 0.703125 1.63281 -1.42969 0.404297 0.84375 -2.51562 -2.39062 -0.292969 0.824219 -1.03125 -2.8125 0.648438 -2.0625 0.19043 3.125 -2.07812 0.0400391 0.597656 2.21875 -1.16406 -2.5625 -4.90625 0.429688 -1.63281 -1.02344 0.199219 -1.44531 -2.28125 -0.134766 -0.216797 0.0512695 1.41406 3.60938 0.263672 -0.212891 -0.333984 0.324219 -0.949219 -0.458984 0.267578 0.236328 3.20312 -1.42188 5.8125 2.34375 -3.125 -3.57812 -5.0625 -2.375 2.09375 -2 -0.015564 -4.65625 2.85938 -1.46875 0.542969 2.70312 -1 0.675781 -0.0305176 -1.34375 -2.25 -0.765625 -1.67188 -0.310547 -0.734375 -0.90625 -0.335938 -0.535156 3.01562 -3.23438 -0.980469 -4.3125 4.40625 2.09375 -0.691406 2.67188 -1.79688 -0.482422 2.23438 1.85156 -1.23438 -3.65625 -1.71875 -1.08594 0.523438 -0.0258789 -2 1.42188 -1.35156 -2.89062 0.726562 6.09375 2.28125 -0.277344 -1.86719 2.5 2.48438 2.17188 0.710938 -4.21875 -3.25 -0.316406 1.89844 -2.90625 1.03125 -0.808594 -2.65625 -1.39062 -1.42969 -1.71875 0.941406 -0.71875 -0.380859 3.79688 0.730469 1.34375 -1.53906 -1.73438 0.867188 -0.796875 -1.78906 -1.07812 0.859375 0.859375 -0.478516 0.478516 -1.0625 1.14062 1.28125 4.125 -6.28125 1.71094 -1.02344 0.045166 3.3125 -1.57031 0.824219 -3.65625 -0.388672 -1.88281 0.617188 -2.26562 1.67969 5.875 1.28906 -2.60938 1.25781 -2.95312 0.384766 1.08594 -1.17188 -1.24219 0.416016 1.84375 4.3125 -0.251953 4.40625 3.5 -2.07812 -2.15625 -0.0181885 -0.197266 1.49219 3.0625 -2.5 -2.51562 -0.816406 -3.39062 -1.82812 -1.27344 4 0.78125 -0.808594 3.23438 -1.23438 -5.40625 1.83594 0.451172 2.84375 -2.28125 -1.57812 -1.40625 0.558594 -4.09375 3.26562 -2.78125 -3.64062 2.5 0.742188 4.625 -2.71875 0.734375 -0.204102 -2.85938 -0.255859 -0.00958252 1.8125 -0.734375 1.82812 -2.84375 0.871094 2.26562 2.1875 -1.35156 -2.78125 1.22656 0.960938 -3.53125 -2.39062 -1.21094 1.97656 -0.0175781 1.58594 0.582031 2.46875 -3.40625 2.14062 0.027832 1.64844 0.660156 -2.60938 0.492188 -0.695312 0.566406 -0.792969 0.112305 2.15625 3.4375 -1.29688 -2 0.351562 3.17188 1.8125 -2.42188 0.169922 0.503906 2.21875 4.09375 -0.699219 -0.234375 -0.902344 1.88281 -4.25 -0.243164 0.182617 0.800781 0.427734 1.42969 1.14062 0.496094 2.67188 -1.36719 -0.458984 0.197266 -0.347656 1.07031 2.96875 1.21094 2.98438 3.23438 2.9375 0.65625 -0.00726318 -0.671875 1.15625 -0.419922 -1.16406 -3.04688 1.1875 -1.78125 -0.225586 0.396484 1.60156 -2.34375 -2.76562 1.26562 -0.625 0.0524902 -1.25 -1.03906 2.04688 -2.26562 1.875 -0.285156 1.59375 -0.423828 0.578125 0.0698242 0.90625 3.375 1.10156 2.57812 1.89062 2.92188 -0.25 -1.83594 -1.30469 -2.375 -0.238281 -3.4375 2.45312 0.980469 1.26562 0.106934 -1.67188 4.09375 0.152344 -1.05469 -3.85938 -1.58594 -1.28125 1.67969 0.792969 -2.71875 -0.361328 1.49219 -0.443359 -1.90625 1.88281 -0.957031 2.53125 4.0625 -1.23438 0.158203 1.61719 2.85938 3.48438 0.914062 -0.636719 1.10938 3.25 -3.48438 -2.125 -1.6875 2.70312 1.99219 -0.361328 -4 
-1.04688 0.122559 4.15625 0.196289 1.64062 -1.4375 -0.178711 0.65625 2.125 0.503906 1.28906 1.15625 0.470703 -0.353516 -0.683594 0.847656 1.50781 -1.94531 -0.251953 -2.59375 -0.0161133 1.71875 -1.34375 3 -2.125 3.84375 -2.23438 0.671875 -0.582031 1.22656 -2.5 1.6875 3.54688 0.339844 -0.601562 -2.875 -0.194336 -3.32812 -0.227539 1.95312 -0.279297 -2.23438 -1.30469 3.21875 1.51562 3.34375 -1.61719 -5.6875 2.84375 0.361328 -0.494141 2.40625 0.3125 0.898438 0.194336 -0.625 1.32812 1.14062 -2.89062 2.3125 -0.511719 1.52344 -0.96875 -0.859375 -3.39062 -1.34375 -1.78125 -1.00781 -3.20312 3.75 0.283203 0.738281 1.07812 -1.46875 1.04688 1.24219 0.652344 1.38281 -0.511719 0.386719 1.48438 -1.15625 -0.804688 3.09375 1.53906 2.39062 -0.078125 -1.09375 0.933594 1.47656 -1.39844 2.32812 2.70312 -2.20312 -2.96875 -1.58594 0.429688 1.13281 -2.90625 0.632812 -0.435547 -3.09375 -0.867188 3.92188 0.361328 0.361328 0.462891 -3.375 0.223633 0.699219 5.59375 -1.48438 -1.58594 -0.535156 0.578125 -4.90625 -1.85938 3.42188 0.416016 -0.137695 -3.42188 -1.82812 -1.73438 -0.128906 -0.0235596 -3.45312 -2.57812 2.42188 -5.875 -2.04688 1.03125 2.35938 1.5 3.20312 -0.804688 -3.76562 0.785156 2.51562 -0.24707 -0.335938 0.660156 -0.00964355 -1.28906 0.769531 -1.22656 -3.95312 2.54688 -0.757812 -2.82812 -0.341797 0.404297 3.34375 -2.875 2.40625 2.40625 2.25 6.09375 3.28125 0.318359 -2.32812 0.322266 -2.15625 2.65625 -0.335938 4.5 0.151367 0.365234 0.0708008 1.35938 3.70312 2.375 -0.114258 0.953125 -1.91406 1.99219 -0.423828 1.6875 1.29688 -0.777344 -4.84375 -4.3125 -3.10938 -1.79688 -0.679688 0.216797 1.21875 -0.0756836 -1.66406 -0.921875 0.408203 0.101562 2.03125 0.667969 -0.671875 -0.945312...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]]

nod-ai/shark-ai#896
I have to find a way to clean that up, but in the meantime this exports something functional for continuing to check numerics. @AmosLewis

@AmosLewis
Contributor Author

AmosLewis commented Feb 2, 2025

nod-ai/shark-ai#896 I have to find a way to clean that up, but in the meantime this exports something functional for continuing to check numerics. @AmosLewis

@dan-garvey It brings new issues. index_put's input and the values to be put should have the same element type, but your change makes them different: the input is f16 but the values are f8.
fp8_dan.mlir

 /home/chi/src/iree-build/tools/iree-compile fp8_dan.mlir \
  --iree-hip-target=gfx942 \
  -o=fp8_dan.vmfb \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
fp8_dan.mlir:2018:12: error: failed to legalize operation 'torch.aten.index_put.hacked_twin' that was explicitly marked illegal
    %678 = torch.aten.index_put %667, %677, %674, %false_143 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f8E4M3FNUZ>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
           ^
fp8_dan.mlir:2018:12: note: see current operation: %1317 = "torch.aten.index_put.hacked_twin"(%1303, %1316, %1312, %51) : (!torch.vtensor<[?,32,8,128],f16>, !torch.list<vtensor>, !torch.vtensor<[?,32,8,128],f8E4M3FNUZ>, !torch.bool) -> !torch.vtensor<[?,32,8,128],f16>
fp8_dan.mlir:23240:12: error: failed to legalize operation 'torch.aten.index_put.hacked_twin' that was explicitly marked illegal
    %729 = torch.aten.index_put %714, %728, %708, %false_118 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[1,1,8,128],f8E4M3FNUZ>, !torch.bool -> !torch.vtensor<[?,32,2,32,8,128],f16>
           ^
fp8_dan.mlir:23240:12: note: see current operation: %1300 = "torch.aten.index_put.hacked_twin"(%1289, %1299, %1282, %68) : (!torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<vtensor>, !torch.vtensor<[1,1,8,128],f8E4M3FNUZ>, !torch.bool) -> !torch.vtensor<[?,32,2,32,8,128],f16>

indexput.torch.mlir
torch-mlir-opt --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg --convert-torch-to-tmtensor indexput.torch.mlir > indexput.linalg.mlir --debug

Trying to match "(anonymous namespace)::ConvertAtenIndexPutHackedTwinOp"
    ** Failure : Input element type should be same as the values element type.
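
One possible way to reconcile the types before this lowering, shown as a hedged sketch (value names are taken from the first error above; the %int5 constant is hypothetical, 5 being the torch dtype code for float16): cast the f8 values up to the cache dtype so index_put sees matching element types.

    %int5 = torch.constant.int 5  // 5 == torch.float16
    %values_f16 = torch.prims.convert_element_type %674, %int5 : !torch.vtensor<[?,32,8,128],f8E4M3FNUZ>, !torch.int -> !torch.vtensor<[?,32,8,128],f16>
    %678 = torch.aten.index_put %667, %677, %values_f16, %false_143 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>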

Besides, for the torch.aten.mm, only the dims changed from static to dynamic; I didn't see any type change:
torch.aten.mm %0, %1 : !torch.vtensor<[32,4096],f8E4M3FNUZ>, !torch.vtensor<[4096,4096],f8E4M3FNUZ> -> !torch.vtensor<[32,4096],f8E4M3FNUZ>
->
torch.aten.mm %568, %565 : !torch.vtensor<[?,4096],f8E4M3FNUZ>, !torch.vtensor<[4096,4096],f8E4M3FNUZ> -> !torch.vtensor<[?,4096],f8E4M3FNUZ>

@AmosLewis
Contributor Author

The numerics are fixed with the new llama_8b_f8.mlir generated from Dan's local shark-ai. Generated artifacts: llama_8b_f8.vmfb, llama_8b_f8_prefill.tracy, llama_8b_f8_decode.tracy. BTW, the current fix in nod-ai/shark-ai#896 is not enough for now (a98a332); I copied the mlir file from Dan's local directory.
