
DeepSeek-Coder-V2-Lite-Instruct not working when quantized to FP8 using AutoFP8 #29

Closed
Syst3m1cAn0maly opened this issue Jul 10, 2024 · 10 comments · Fixed by vllm-project/vllm#6417


@Syst3m1cAn0maly

Hi!

I quantized DeepSeek-Coder-V2-Lite-Instruct to FP8 using AutoFP8, but when I try to run it with vLLM I get the following error:

RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'
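For context, the error shows up when I try to run the checkpoint with vLLM's offline API, roughly like this (a minimal sketch, not my exact command; the engine arguments may differ):

from vllm import LLM, SamplingParams

# Sketch only: load the FP8 checkpoint produced by the script below.
# trust_remote_code and max_model_len are illustrative, adjust to your setup.
llm = LLM(
    model="/path-to-models/DeepSeek-Coder-V2-Lite-Instruct-FP8",
    trust_remote_code=True,
    max_model_len=4096,
)
outputs = llm.generate(
    ["Write a Python function that reverses a string."],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)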

I ran the quantization using this script:

from datasets import load_dataset
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "/path-to-models/DeepSeek-Coder-V2-Lite-Instruct"
quantized_model_dir = "/path-to-models/DeepSeek-Coder-V2-Lite-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

# Load and tokenize 512 dataset samples for calibration of activation scales
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

# Load the model, quantize, and save checkpoint
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)

and I got the following output:

Quantizing weights: 100%|██████████| 8809/8809 [00:00<00:00, 11539.48it/s]
Calibrating activation scales: 100%|██████████| 512/512 [11:04<00:00,  1.30s/it]
DeepseekV2ForCausalLM(
  (model): DeepseekV2Model(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekV2DecoderLayer(
        (self_attn): DeepseekV2Attention(
          (q_proj): FP8StaticLinear()
          (kv_a_proj_with_mqa): FP8StaticLinear()
          (kv_a_layernorm): DeepseekV2RMSNorm()
          (kv_b_proj): FP8StaticLinear()
          (o_proj): FP8StaticLinear()
          (rotary_emb): DeepseekV2YarnRotaryEmbedding()
        )
        (mlp): DeepseekV2MLP(
          (gate_proj): FP8StaticLinear()
          (up_proj): FP8StaticLinear()
          (down_proj): FP8StaticLinear()
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekV2RMSNorm()
        (post_attention_layernorm): DeepseekV2RMSNorm()
      )
      (1-26): 26 x DeepseekV2DecoderLayer(
        (self_attn): DeepseekV2Attention(
          (q_proj): FP8StaticLinear()
          (kv_a_proj_with_mqa): FP8StaticLinear()
          (kv_a_layernorm): DeepseekV2RMSNorm()
          (kv_b_proj): FP8StaticLinear()
          (o_proj): FP8StaticLinear()
          (rotary_emb): DeepseekV2YarnRotaryEmbedding()
        )
        (mlp): DeepseekV2MoE(
          (experts): ModuleList(
            (0-63): 64 x DeepseekV2MLP(
              (gate_proj): FP8StaticLinear()
              (up_proj): FP8StaticLinear()
              (down_proj): FP8StaticLinear()
              (act_fn): SiLU()
            )
          )
          (gate): MoEGate()
          (shared_experts): DeepseekV2MLP(
            (gate_proj): FP8StaticLinear()
            (up_proj): FP8StaticLinear()
            (down_proj): FP8StaticLinear()
            (act_fn): SiLU()
          )
        )
        (input_layernorm): DeepseekV2RMSNorm()
        (post_attention_layernorm): DeepseekV2RMSNorm()
      )
    )
    (norm): DeepseekV2RMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=102400, bias=False)
)
Saving the model to /path-to-models/DeepSeek-Coder-V2-Lite-Instruct-FP8

What can I do to quantize this kind of model correctly?

@mgoin
Member

mgoin commented Jul 10, 2024

Hey @Syst3m1cAn0maly, we don't support quantization in vLLM for non-Mixtral MoEs yet. We are currently working through a refactor to support Qwen2 and DeepSeek-V2: vllm-project/vllm#6088

@Jiayi-Pan

Thank you for the efforts. Looking forward to FP8 support for DSv2❤️

@robertgshaw2-neuralmagic

> Thank you for the efforts. Looking forward to FP8 support for DSv2❤️

Working on this today

@robertgshaw2-neuralmagic

robertgshaw2-neuralmagic commented Jul 14, 2024

@Syst3m1cAn0maly
Author

Thanks a lot.
I will try as soon as possible.
Do I need to change the settings to quantize this model properly with AutoFP8, or should it work as-is? (I saw there was a specific setting for Mixtral models regarding MoE gates.)

@robertgshaw2-neuralmagic

robertgshaw2-neuralmagic commented Jul 14, 2024

You need to skip the routing gate:

# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    # skip the lm_head and the expert routing gates
    ignore_patterns=["re:.*lm_head", "re:.*gate.weight"],
)
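If you want to sanity-check that the gates were actually skipped, something along these lines should work on the saved checkpoint (a rough sketch, assuming the model is saved as safetensors shards; exact file and key names may differ):

import glob
from safetensors import safe_open

# Sketch: gate and lm_head weights should keep their original dtype (e.g. bfloat16),
# while the quantized linear weights are stored as float8_e4m3fn.
for shard in glob.glob("/path-to-models/DeepSeek-Coder-V2-Lite-Instruct-FP8/*.safetensors"):
    with safe_open(shard, framework="pt") as f:
        for name in f.keys():
            if name.endswith("mlp.gate.weight") or "lm_head" in name:
                print(name, f.get_tensor(name).dtype)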

The other thing I'm not sure about is the following layers:

self_attn.kv_a_proj_with_mqa
self_attn.kv_b_proj

I'm working on seeing how sensitive they are now.
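If they do turn out to be sensitive, the same ignore mechanism should let you keep them in higher precision as well, e.g. (untested, just extending the patterns above):

# Hypothetical fallback: also leave the MLA kv projections unquantized.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=[
        "re:.*lm_head",
        "re:.*gate.weight",
        "re:.*kv_a_proj_with_mqa",
        "re:.*kv_b_proj",
    ],
)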

@Syst3m1cAn0maly
Author

Thanks, I will try with these settings.

@robertgshaw2-neuralmagic

FYI - the config above is good, but it needed one more tweak on the vLLM side.

@Syst3m1cAn0maly
Author

@robertgshaw2-neuralmagic thanks a lot for the work

@Syst3m1cAn0maly
Author

I tested it today and it now works as expected, thanks!
