
[DNM] cherry pick fp8 attn nonsense with hack cream #907

Draft · wants to merge 6 commits into base: main

Conversation

dan-garvey (Member) commented on Feb 4, 2025

 python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/home/chi/src/test/llama/dan/fp8_attn.irpa \
--output-mlir=/home/chi/src/test/llama/dan/f8_attn_chi_castf32_roctorch.mlir \
--output-config=/home/chi/src/test/llama/dan/config_attn_chi.json \
--bs=1 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 --use-attention-mask --use-hf

sudo cp /home/dan/SHARK-Platform/fp8_attn.irpa fp8_attn.irpa

AmosLewis commented on Feb 19, 2025

The default bs1_input_32 case compiles with iree-compile and runs without NaNs, but I hit an iree-compile error for bs4_input128. Should I file an iree-compile issue, or can it be fixed here?
llama_fp8_attn8_bs4_128_bug.txt

/sharedfile/attn/128/fp8_attn.mlir:29732:13: error: 'util.call' op function type mismatch; expected '(tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>' but callee is '(tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>'
    %2032 = "util.call"(%2026, %2027, %2028, %2031, %2030) <{callee = @sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32}> : (tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>
            ^
/sharedfile/attn/128/fp8_attn.mlir:29732:13: note: see current operation: %1905 = "util.call"(%1899, %1900, %1901, %1904, %1903) <{callee = @sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32}> : (tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>

The exported MLIR for bs4: llama_fp8_attn8_bs4_input128.mlir
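
For reference, the verifier error is a rank mismatch on the mask operand: the util.call site passes a rank-4 mask (tensor<?x?x?x?xf8E4M3FNUZ>), while the generated sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32 callee declares a rank-2 mask (tensor<?x?xf8E4M3FNUZ>). Below is a minimal Python sketch of that mismatch only; the [bs, 1, q_len, kv_len] mask layout and the concrete sizes are assumptions taken from the kernel name, not a confirmed description of the exporter or of the eventual fix.

import torch

# Assumed shapes for illustration (bs=4, 128-token prefill); not the actual sharktank fix.
bs, q_len, kv_len = 4, 128, 128

# What the call site passes: a broadcastable rank-4 mask,
# which lowers to tensor<?x?x?x?xf8E4M3FNUZ>.
mask_at_call_site = torch.zeros(bs, 1, q_len, kv_len)

# What the generated kernel declares: a rank-2 mask,
# i.e. tensor<?x?xf8E4M3FNUZ> in the callee signature.
mask_in_callee_signature = torch.zeros(q_len, kv_len)

# util.call verifies operand types against the callee's function type, so the
# ranks must agree: either the exporter squeezes the mask to rank 2 before the
# call, or the kernel template is regenerated to accept a rank-4 mask.
assert mask_at_call_site.dim() != mask_in_callee_signature.dim()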
