Bug: QWEN2 quantization GGML_ASSERT #7805

Closed · bartowski1182 opened this issue Jun 6, 2024 · 74 comments
Labels: bug-unconfirmed, high severity, stale

@bartowski1182 (Contributor) commented Jun 6, 2024

What happened?

When attempting to quantize Qwen2 7B instruct to IQ2_XS I get the following assert:

GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0

Anything I can provide to debug? I'm uploading the f32 file and imatrix now so it can be recreated.

Attempting IQ2_S now; will update if it fails in the same way. Update: it fails in the same way, on the same block.

Name and Version

Version b3086, Ubuntu 22.04

What operating system are you seeing the problem on?

Linux

Relevant log output

[ 327/ 339]              blk.27.attn_norm.weight - [ 3584,     1,     1,     1], type =    f32, size =    0.014 MB
[ 328/ 339]               blk.27.ffn_down.weight - [18944,  3584,     1,     1], type =    f32, converting to iq2_xs .. GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0
GGML_ASSERT: ggml-quants.c:12083: grid_index >= 0

(PS: is this a high severity or medium/low?)

bartowski1182 added the bug-unconfirmed and high severity labels on Jun 6, 2024
@bartowski1182 (Contributor Author) commented Jun 6, 2024

Interestingly, when I try Q2_K (and, it seems, any other K-quant; I tried Q4_K_S to check) I see a different error:

[ 329/ 339]               blk.27.ffn_gate.weight - [ 3584, 18944,     1,     1], type =    f32, converting to q2_K .. ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0
ggml_validate_row_data: found nan value at block 0

(repeats several times)

@bartowski1182 (Contributor Author)

Files for recreating it locally are up here: https://huggingface.co/bartowski/Qwen2-7B-Instruct-GGUF/tree/main

@SabinStargem commented Jun 6, 2024

If 'Quill' is indeed the leaked version of Qwen2, then something might have changed in llama.cpp that broke Qwen2 conversions to GGUF. I have been using a Quill GGUF from Mradermacher's repository, made about six days ago, and it works fine. This assumes that the official release of Qwen2 isn't altered from Quill.

@slaren (Collaborator) commented Jun 6, 2024

I think that the imatrix contains nan values. Did it complete without nan ppl?

@bartowski1182 (Contributor Author)

@slaren are you referring to the --no-ppl option, by chance, or did you mean something else?

When I remove --no-ppl, the imatrix output looks like this:

llama_cpp_gpu-1 | [1]nan,[2]nan,[3]nan,[4]nan,[5]nan,[6]nan,[7]nan,[8]nan,[9]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 10 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix
llama_cpp_gpu-1 | [10]nan,[11]nan,[12]nan,[13]nan,[14]nan,[15]nan,[16]nan,[17]nan,[18]nan,[19]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 20 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix
llama_cpp_gpu-1 | [20]nan,[21]nan,[22]nan,[23]nan,[24]nan,[25]nan,[26]nan,[27]nan,[28]nan,[29]nan,
llama_cpp_gpu-1 | save_imatrix: stored collected data after 30 chunks in /models/Qwen2-7B-Instruct-GGUF/Qwen2-7B-Instruct.imatrix

@slaren (Collaborator) commented Jun 6, 2024

Yes. If you ever get nan in these values, the imatrix is pretty much guaranteed to be useless. In fact the model is completely unusable as long as that happens.
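
(An aside for illustration: a minimal sketch of the kind of sanity check this implies, assuming the collected imatrix sums are available as a plain float array; the function name is made up and this is not llama.cpp's actual API.)

#include <cmath>
#include <vector>

// Hypothetical helper: returns true if any collected imatrix value is NaN or +/-inf.
// If this triggers, the imatrix (and likely the activations it was computed from)
// cannot be trusted.
static bool imatrix_has_bad_values(const std::vector<float> & values) {
    for (float v : values) {
        if (!std::isfinite(v)) {
            return true;
        }
    }
    return false;
}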

@bartowski1182 (Contributor Author)

Yikes.. so what the hell am I doing wrong that's not affecting others (or is it affecting others, but it's just not obvious because they didn't calculate an imatrix)?

@bartowski1182 (Contributor Author)

I even tried updating to master and using different datasets.

@slaren (Collaborator) commented Jun 6, 2024

If it happens while using CUDA, it usually means that the model is producing activations that cannot be represented in a float16. We have workarounds to force float32 compute in these cases for the known models that need it, but it is very unusual.
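
(For context, a generic illustration rather than llama.cpp code: float16 can only represent finite magnitudes up to about 65504, so any activation or intermediate product beyond that overflows to infinity, and later operations can turn that into NaN.)

#include <cmath>
#include <cstdio>

int main() {
    const float FP16_MAX   = 65504.0f;  // largest finite half-precision value
    const float activation = 70000.0f;  // hypothetical out-of-range intermediate value
    // a real cast to FP16 would give +inf here; we only flag the overflow
    if (std::fabs(activation) > FP16_MAX) {
        std::printf("%.1f does not fit in float16; it would overflow to inf\n", activation);
    }
    return 0;
}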

@maziyarpanahi

I compute the imatrix for Qwen2 models on pure CPU, with no CUDA build:
(screenshot attached)

@bartowski1182 (Contributor Author)

Ooh, so that's gotta be it then.. strangely, I converted to f32 instead of f16, but I assume that's not enough?

Trying to run my F32 conversion with CUDA gives pure garbage.

Running it with -ngl 0 gives actual results.

@slaren (Collaborator) commented Jun 6, 2024

It may work with flash attention (-fa).

@slaren (Collaborator) commented Jun 6, 2024

The problem seems to be in the KV. This patch should allow it to work:

diff --git a/llama.cpp b/llama.cpp
index 32264a00..6324af70 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7009,7 +7009,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);

-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -15376,6 +15376,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, c

However, it also seems to work with flash attn even without GGML_PREC_F32. @JohannesGaessler do you know why that may be? is the fattn implementation less susceptible to intermediate values out of range of f16?

@bartowski1182 (Contributor Author)

As always, super appreciative of the investigation and solve @slaren, you're awesome <3

@JohannesGaessler (Collaborator)

is the fattn implementation less susceptible to intermediate values out of range of f16?

Assuming the problem is overflows rather than underflows, I think the most likely reason is that my FlashAttention implementations scale Q with 1.0f/sqrtf(float(n_embd_head)) up front (because it's faster), while regular attention does the KQ matrix multiplication first and applies the scale to the result. But this factor is not very large, so there is a good chance that for some specific inputs you would still get overflows.
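
(A rough numerical sketch of that point, with illustrative values only, not the actual kernel code: pre-scaling Q keeps the K·Q accumulator roughly sqrt(n_embd_head) times smaller, which can be the difference between staying inside and overflowing the float16 range.)

#include <cmath>
#include <cstdio>

int main() {
    const int   d     = 128;                       // head dimension (n_embd_head)
    const float scale = 1.0f / std::sqrt((float) d);
    const float qi = 30.0f, ki = 30.0f;            // exaggerated per-element magnitudes

    // regular attention: accumulate K.Q first, apply the scale afterwards
    const float kq_unscaled  = d * qi * ki;            // 115200 > 65504 (FP16 max)
    // FlashAttention-style: scale Q first, then accumulate
    const float kq_prescaled = d * (qi * scale) * ki;  // ~10182, within FP16 range

    std::printf("unscaled: %.0f, pre-scaled: %.0f, final result after scaling: %.0f\n",
                kq_unscaled, kq_prescaled, kq_unscaled * scale);
    return 0;
}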

@bartowski1182 (Contributor Author)

@slaren would ROCm also be experiencing this bug or is it just in CUDA code? Curious if other odd bugs out there can be explained by this

@dillfrescott

Here's my current output for Qwen2. llama.cpp seems to be utterly broken right now.

(screenshot attached)

Here is my run command:

main -n -1 -m "qwen2-7b-instruct-q8_0.gguf" ^
--color --multiline-input -t 16 -ngl "900" ^
--log-disable -c 8192 --temp 0.6 --interactive-first ^
--interactive

@bartowski1182 (Contributor Author)

@dillfrescott can you try adding -fa and seeing if that helps?

@dillfrescott

@bartowski1182 Damn, that totally fixed it! What does that flag do??

@bartowski1182 (Contributor Author)

It enables flash attention; apparently it runs the operations in a different order, which prevents the issue.

@dillfrescott

Interesting

@JustinLin610 (Contributor)

Here's my current output for Qwen2. llama.cpp seems to be utterly broken right now.

(screenshot attached)

Here is my run command:

main -n -1 -m "qwen2-7b-instruct-q8_0.gguf" ^
--color --multiline-input -t 16 -ngl "900" ^
--log-disable -c 8192 --temp 0.6 --interactive-first ^
--interactive
./main -m qwen2-7b-instruct-q5_k_m.gguf -n 512 --color -i -cml -f prompts/chat-with-qwen.txt

This is what I do. Are you using a GPU? It might be a CUDA issue, I'm not sure. This also happens in Ollama for some users, but my friend inputs /set num_gpu 0 and finds it works. Things are a bit strange for me...

@dillfrescott

@JustinLin610 I am using a 4090 and have CUDA 12.5 installed. It usually works with every other model, so I'm just not sure why.

@mann1x commented Jun 7, 2024

@dillfrescott can you try adding -fa and seeing if that helps?

Any reason why, with flash_attn enabled for the imatrix calculation, the time to complete increased from about 2.5 h to 5 h (qwen2-72b)?
I thought it would be faster... does it mean that without fa the imatrix was not computed properly?

@JohannesGaessler (Collaborator)

@slaren would ROCm also be experiencing this bug or is it just in CUDA code? Curious if other odd bugs out there can be explained by this

Not him but the fundamental issue is that the KQ matrix has values outside the FP16 range. On Volta or newer or RX 7000 or newer llama.cpp does all matrix multiplications with a batch size > 64 using FP16 cuBLAS/HIPBLAS so you get garbage results. The fix works by instead calculating the KQ matrix with FP32 precision. The huge values from the KQ matrix are then fed to softmax which brings them into the 0-1 range again and fixes the numerical issues.

You should be able to fix the numerical issues with better performance on Ampere or newer by calculating the KQ matrix as BF16 instead of FP16 (with some precision loss which is probably negligible if you're going to apply softmax anyways). Annoyingly there is to my knowledge no way to make cuBLAS return the result as FP32 (even though that is the only output format that tensor cores actually support) because then we would also save some time on the conversion back to FP32.
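
(To make the softmax point concrete, here is a minimal numerically stable softmax in FP32, using the standard max-subtraction trick; this is an illustration, not the CUDA kernel itself. Even huge KQ logits end up in the 0-1 range.)

#include <algorithm>
#include <cmath>
#include <vector>

// Stable softmax over FP32 logits: subtracting the max keeps exp() from overflowing,
// and every output lands in [0, 1] no matter how large the inputs were.
std::vector<float> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> out(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = std::exp(logits[i] - max_logit);
        sum   += out[i];
    }
    for (float & v : out) {
        v /= sum;   // normalize: each entry is now a probability in [0, 1]
    }
    return out;
}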

@JohannesGaessler (Collaborator)

Any reason why, with flash_attn enabled for the imatrix calculation, the time to complete increased from about 2.5 h to 5 h (qwen2-72b)?
I thought it would be faster... does it mean that without fa the imatrix was not computed properly?

It depends on what hardware you have; on AMD, large-batch FA performance is bad. I have received some reports about bad performance with partial offloading, but so far I have never been able to reproduce this.

And note that I think just adding -fa is not a reliable fix, so there is a good chance that you will still get NaNs at some point.

@mann1x commented Jun 7, 2024

It depends on what hardware you have; on AMD, large-batch FA performance is bad. I have received some reports about bad performance with partial offloading, but so far I have never been able to reproduce this.

For normal inference fa works fine; it's running on an RTX 3090 on Linux, so only 12 layers are in VRAM.
I never got any NaN, either with or without fa.

With -fa:

llama_new_context_with_model: n_batch    = 204
llama_new_context_with_model: n_ubatch   = 204
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    68.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =    12.00 MiB
llama_new_context_with_model: KV self size  =   80.00 MiB, K (f16):   40.00 MiB, V (f16):   40.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.58 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  2500.71 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     6.60 MiB
llama_new_context_with_model: graph nodes  = 2487
llama_new_context_with_model: graph splits = 956

system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 117.352 ms
compute_imatrix: computing over 100 chunks with batch_size 204
ggml_backend_cuda_graph_compute: CUDA graph update failed
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to too many consecutive updates
ggml_backend_cuda_graph_compute: CUDA graph update failed
compute_imatrix: 177.86 seconds per pass - ETA 4 hours 56.43 minutes

Without -fa:

llama_new_context_with_model: n_ctx      = 224
llama_new_context_with_model: n_batch    = 204
llama_new_context_with_model: n_ubatch   = 204
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =    59.50 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =    10.50 MiB
llama_new_context_with_model: KV self size  =   70.00 MiB, K (f16):   35.00 MiB, V (f16):   35.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.58 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  2500.71 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     6.57 MiB
llama_new_context_with_model: graph nodes  = 2806
llama_new_context_with_model: graph splits = 956

system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0
 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 119.284 ms
compute_imatrix: computing over 100 chunks with batch_size 204
ggml_backend_cuda_graph_compute: CUDA graph update failed
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to too many consecutive updates
ggml_backend_cuda_graph_compute: CUDA graph update failed
compute_imatrix: 98.19 seconds per pass - ETA 2 hours 43.65 minutes

@grapevine-AI

Hello.
I found that if you make a Q6_K first and then re-quantize from it, you can make i-quants with an imatrix.
Q6_K is extremely high quality, so I think this method probably produces sound quants.
In fact, my i-quants seem to work correctly.
Would you please help to verify?

@mann1x commented Jun 13, 2024

Would you please help to verify?

Do you have your quants uploaded somewhere?

@grapevine-AI

My example is here.
However, please note that I use both English and Japanese text because I am not sure what data English speakers typically use.

@bartowski1182 (Contributor Author) commented Jun 13, 2024

Ignore this, it's because I have a P100; bad report.

Hmm, this is odd..

Trying to run a Qwen2 quant (https://huggingface.co/bartowski/Tess-v2.5-Qwen2-72B-GGUF/blob/main/Tess-v2.5-Qwen2-72B-Q2_K.gguf) with GPU offloading yields a new assert:

GGML_ASSERT: ggml-cuda/dmmv.cu:665: false

Based on the assert, my guess is that it's because ffn_down.weight was quantized to IQ4_NL? But obviously I'm not positive. Not offloading works fine; redownloading Q4_K_M to test, since I see that didn't use IQ4_NL.

@slaren any idea why this wasn't an issue in the past? Is there something special in Qwen2 that makes it want to use those quant types? Also, it's strange because I thought IQ quants were fully supported on CUDA, so I don't get why those asserts exist.

Edit: as suspected, no issues with Q4_K_M.

@bartowski1182 (Contributor Author)

Also gonna pull @JohannesGaessler back in, since he seems to know the area very well.

@JohannesGaessler (Collaborator)

GGML_ASSERT: ggml-cuda/dmmv.cu:665: false

If you are on a P100 or Maxwell or older, that is expected. For all quants other than the legacy quants and k-quants there is only an MMVQ implementation (which needs Pascal != P100 or newer) but no DMMV implementation. If you are using a GPU that should be supported, that is a bug.

@bartowski1182 (Contributor Author)

Ah dammit, thank you, that's silly of me; I have a 3090 and a P100, and this pushed it onto that card.

You say P100 or Maxwell/older; does that imply the P40 is fine?

@slaren (Collaborator) commented Jun 13, 2024

Is there something special in Qwen2 that makes it want to use those quant types?

Most k- and i-quants have a block size of 256, but the ffn_down in this model has a dimension that is not divisible by 256. IQ4_NL is used as a fallback for IQ4_XS and lower, since it has a block size of 32. For higher quants Q5_x or Q8_0 is used as the fallback instead, which is compatible with this GPU.
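
(A simplified sketch of that fallback idea, with made-up names; the real logic lives in llama.cpp's quantization code and handles more types than shown here.)

#include <cstdint>

// k-quants and most i-quants use 256-wide blocks; rows whose size is not divisible
// by 256 need a type with 32-wide blocks instead.
enum class QType { IQ2_XS, IQ3_XXS, IQ4_XS, IQ4_NL, Q5_0, Q8_0 };

QType pick_type(QType wanted, int64_t n_per_row) {
    const int64_t QK_K = 256;                    // block size of k-/i-quants
    if (n_per_row % QK_K == 0) {
        return wanted;                           // divisible: keep the requested type
    }
    // low-bit i-quants fall back to IQ4_NL (block size 32) ...
    if (wanted == QType::IQ2_XS || wanted == QType::IQ3_XXS || wanted == QType::IQ4_XS) {
        return QType::IQ4_NL;
    }
    // ... while higher-bit quants fall back to Q5_0/Q8_0-style 32-wide types
    return QType::Q8_0;
}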

@JohannesGaessler (Collaborator)

You say P100 or Maxwell/older; does that imply the P40 is fine?

Yes. P100s have compute capability 6.0; all other Pascal GPUs (including P40s) have compute capability 6.1, which is the minimum CC for the __dp4a instruction (per-byte dot product).
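
(For readers unfamiliar with it, __dp4a computes a per-byte dot product: each 32-bit operand is treated as four packed signed 8-bit integers, the corresponding bytes are multiplied, and the sum is added to a 32-bit accumulator. A plain C++ model of those semantics, not the CUDA intrinsic itself, looks roughly like this.)

#include <cstdint>

// Reference semantics of the signed __dp4a(a, b, c): unpack four int8 lanes from each
// 32-bit word, multiply them pairwise, and add the sum to the accumulator c.
int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t) (((uint32_t) a >> (8 * i)) & 0xff);
        const int8_t bi = (int8_t) (((uint32_t) b >> (8 * i)) & 0xff);
        c += (int32_t) ai * (int32_t) bi;
    }
    return c;
}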

@bartowski1182 (Contributor Author)

Thanks to both slaren and Johannes, that's good info, I appreciate it :D

Back on the original subject of this issue, is there a reason your proposed change wasn't opened as a PR? I can open one if that would help; I'm just unsure if the proposed fix was deemed not appropriate.

@mann1x commented Jun 13, 2024

Would you please help to verify?

I'm testing the IQ3_XXS and it seems to work very well.
I also tried to create the imatrix from the Q6_K, but it didn't work; always NaN.

@slaren (Collaborator) commented Jun 13, 2024

is there a reason your proposed change wasn't opened as a PR?

Do you mean the fix to use fp32 precision for the attention? I didn't open a PR because the fix would affect all models with the qwen2 architecture, and I recall reading that this model has other issues; if this model is not very good, it may not be worth applying a patch that will decrease the performance of all models with the qwen2 architecture. But I may be wrong about that; if you think it is worth it, the fix could be merged.

@bartowski1182 (Contributor Author)

The thing is, we now have a lot of fine-tunes of Qwen2 72B coming out, and presumably they all have this issue (I haven't re-verified), so I figured it would make sense. But maybe it is worth double checking; I didn't realize there would be a big performance hit.

@grapevine-AI

I'm testing the IQ3_XXS and it seems to work very well. I also tried to create the imatrix from the Q6_K, but it didn't work; always NaN.

Thanks for testing.
Sorry, I simplified the explanation too much, so I'll try to make it clear.
I have shown the details in my HF README.

If anyone is interested in it, I would greatly appreciate it if you could attempt to reproduce the steps.

@mann1x commented Jun 14, 2024

If anyone is interested in it, I would greatly appreciate it if you could attempt to reproduce the steps.

Thanks!
I still have the f32, will try again.

@mann1x commented Jun 14, 2024

If anyone is interested in it, I would greatly appreciate it if you could attempt to reproduce the steps.

I tried creating the imatrix from Q8_0, but it came out the same as with f32; I get NaN when quantising.

@grapevine-AI

I tried creating the imatrix from Q8_0, but it came out the same as with f32; I get NaN when quantising.

Thank you!

Oh no. I wonder if this method depends on the imatrix dataset...
I'll research text data in other languages.

@CISC (Contributor) commented Jun 16, 2024

@grapevine-AI Your imatrix is corrupt BTW:

imatrix entry "blk.20.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.18.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.11.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.15.ffn_down_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.13.ffn_down_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.3.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.12.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.12.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.21.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.9.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.13.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.7.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.0.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.27.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.13.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.2.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.4.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.12.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.1.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.6.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.14.ffn_down_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.11.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.16.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.17.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.14.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.19.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.16.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.5.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.11.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.15.ffn_up_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.17.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.10.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.14.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.8.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.15.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.1.ffn_gate_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.1.ffn_down_exps.weight" contains non-normal value -nan, skipping!
imatrix entry "blk.16.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!
imatrix entry "blk.17.ffn_down_exps.weight" contains non-normal value 0.000000, skipping!

I have made some progress, though: by quantizing the BF16 to F16 and flushing anything <±1e-24 to ±0, I can now make fully working quants (with #7825 and #7955 applied) that do not require Flash Attention! However, creating a fully activated imatrix still does not seem feasible...
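
(A minimal sketch of the flush-to-zero step described above; the function name is made up and the threshold is the one mentioned, applied per value before converting to F16.)

#include <cmath>

// Flush magnitudes below a tiny threshold to signed zero before an F32/BF16 -> F16
// conversion. 1e-24 is far below the smallest F16 subnormal (~6e-8), so such values
// would be lost in the conversion anyway; zeroing them explicitly makes that intent clear.
float flush_tiny(float x, float eps = 1e-24f) {
    return std::fabs(x) < eps ? std::copysign(0.0f, x) : x;
}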

@flastir commented Jul 4, 2024

It appears that the Qwen2-72B model stopped functioning correctly after release b3091.

Result in the release b3091:

User: who are you?
Llama: I am Llama, a friendly and helpful chatbot designed to assist you with any questions or tasks you might have. My goal is to make your day easier by providing accurate information and engaging in meaningful conversations.

Result in the next release b3130:

User: who are you?
Llama: !t 10.
The- .4s híst, and of
A) ()/b0–all.
...

In the latest release that I checked, b3291, the model still doesn't work correctly.

The GGUF model was downloaded from here.

I started the server from llama-b3091-bin-win-cuda-cu12.2.0-x64.zip with this command:

server.exe -m Qwen2-72B-Instuct-Q5_K_M-00001-of-00002.gguf -c 2048 -ngl 0 -fa

The GGUF works normally in KoboldCpp v1.69.1 only if "Use FlashAttention" is checked and "Use QuantMatMul (mmq)" is unchecked.

Is there an argument to server.exe that disables MMQ?

@JohannesGaessler (Collaborator)

The GGUF works normally in KoboldCpp v1.69.1 only if "Use FlashAttention" is checked and "Use QuantMatMul (mmq)" is unchecked.

Is there an argument to server.exe that disables MMQ?

If at all possible, please try to reproduce the issue using only llama.cpp code and use git bisect to identify the exact commit that introduced the problem. You can disable MMQ by compiling with GGML_CUDA_FORCE_CUBLAS.

Also this very much sounds like a different problem than what was discussed previously here. Instead of commenting on an existing issue, please open a new issue instead. Otherwise there is a large risk that your issue will not get attention from the right people.

@flastir commented Jul 4, 2024

Also this very much sounds like a different problem than what was discussed previously here. Instead of commenting on an existing issue, please open a new issue instead. Otherwise there is a large risk that your issue will not get attention from the right people.

Thank you! I've found the appropriate issue for my problem: Issue #8025

@LostRuins (Collaborator)

Necro-ing this thread, but just chipping in: IQ4_NL is probably not an ideal fallback, as it breaks compatibility with any backend that does not have i-quant support (e.g. Vulkan), and leaves people unsure why certain K-quants just don't work when they're supposed to.

I'd rather Q4_0 or Q8_0 be the default fallback for k-quants. This is just my 2 cents.

@ggerganov (Owner)

IQ4_NL is probably not an ideal fallback, as it breaks compatibility with any backend that does not have i-quant support (e.g. Vulkan), and leaves people unsure why certain K-quants just don't work when they're supposed to.

Hm, it shouldn't break compatibility: the current implementation will fall back to CPU computation when the backend does not support a certain type. It would be slow, but it would still work.

0cc4m mentioned this issue on Jul 21, 2024
github-actions bot added the stale label on Aug 18, 2024
github-actions bot commented Sep 1, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.
