diff --git a/README.md b/README.md
index 66c08d2a..75a2f12a 100644
--- a/README.md
+++ b/README.md
@@ -164,15 +164,26 @@ bash ./examples/run_qwen3_moe_eagle3_online.sh
# train Qwen3-8B
bash ./examples/run_qwen3_dense_eagle3_online.sh
+
+# train Kimi-K2
+bash ./examples/run_kimi_k2_eagle3_online.sh
```
### 💨 Offline Training
-We have provided a simple startup script to train the Eagle3 model for Llama-3.1-8B-Instruct model in an offline manner. You can run the following command to start the training. Almost Everything is the same as the Online Training Step, except that you don't need to configure anything about target model. Instead, you need to pass `--train-hidden-states-path` to the file.
+We have provided a simple startup script to train the Eagle3 model for Llama-3.1-8B-Instruct and Kimi-K2-Instruct model in an offline manner. You can run the following command to start the training. Almost Everything is the same as the Online Training Step, except that you don't need to configure anything about target model. Instead, you need to pass `--train-hidden-states-path` to the file.
+
+Note: The tokenizer automatically obtained for the Kimi-K2-Instruct model is TikTokenTokenizer. This tokenizer is not a fast model supported by Rust and lacks the interfaces required during data processing. You need to first run the script to generate tokenizer.json in the root directory of the Kimi-K2-Instruct model, and then modify AutoTokenizer to PreTrainedTokenizerFast in scripts/prepare_hidden_states.py. After doing this, the framework can be called normally.
+```bash
+python scripts/convert_kimi_tokenizer.py
+```
+
```bash
# make sure you have sharegpt data prepared
bash ./examples/run_llama3_eagle3_offline.sh
+
+bash ./examples/run_kimi_k2_eagle3_offline.sh
```
### 📈 Wandb Integration
diff --git a/configs/kimi-k2-eagle3.json b/configs/kimi-k2-eagle3.json
new file mode 100644
index 00000000..9222cb93
--- /dev/null
+++ b/configs/kimi-k2-eagle3.json
@@ -0,0 +1,39 @@
+{
+ "architectures": [
+ "LlamaForCausalLMEagle3"
+ ],
+ "eagle_config": {
+ "eagle_aux_hidden_state_layer_ids": [
+ 2,
+ 30,
+ 58
+ ],
+ "use_aux_hidden_state": true
+ },
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 163584,
+ "eos_token_id": 163585,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 7168,
+ "initializer_range": 0.02,
+ "intermediate_size": 18432,
+ "max_position_embeddings": 131072,
+ "max_window_layers": 48,
+ "model_type": "kimi_k2",
+ "num_attention_heads": 64,
+ "num_hidden_layers": 1,
+ "num_key_value_heads":64,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.53.2",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 163840,
+ "draft_vocab_size": 32000
+}
diff --git a/examples/get_hidden_states.sh b/examples/get_hidden_states.sh
new file mode 100644
index 00000000..1be6c3d1
--- /dev/null
+++ b/examples/get_hidden_states.sh
@@ -0,0 +1,11 @@
+torchrun --nproc_per_node=8 \
+ scripts/prepare_hidden_states.py \
+ --model-path /root/models/Kimi-K2-Instruct \
+ --enable-aux-hidden-states \
+ --data-path /root/script/SpecForge/cache/dataset/test.jsonl \
+ --chat-template kimi_k2 \
+ --max-length 2048 \
+ --tp-size 8 \
+ --batch-size 1 \
+ --mem-frac=0.95 \
+ --num-samples 2000
\ No newline at end of file
diff --git a/examples/run_kimi_k2_eagle3_offline.sh b/examples/run_kimi_k2_eagle3_offline.sh
new file mode 100644
index 00000000..40255166
--- /dev/null
+++ b/examples/run_kimi_k2_eagle3_offline.sh
@@ -0,0 +1,21 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# train eagle3 for llama3.1-8b
+NUM_GPUS=${1:-8}
+
+torchrun \
+ --standalone \
+ --nproc_per_node $NUM_GPUS \
+ $ROOT_DIR/scripts/train_eagle3_offline.py \
+ --target-model-path /root/models/Kimi-K2-Instruct \
+ --draft-model-config $ROOT_DIR/configs/kimi-k2-eagle3.json \
+ --train-data-path $ROOT_DIR/cache/dataset/test.jsonl \
+ --train-hidden-states-path $ROOT_DIR/cache/hidden_states/rows_0-5000 \
+ --output-dir $ROOT_DIR/outputs/Kimi-K2-eagle3 \
+ --num-epochs 10 \
+ --batch-size 1 \
+ --learning-rate 1e-4 \
+ --max-length 2048 \
+ --chat-template kimi_k2 \
+ --cache-dir $ROOT_DIR/cache
\ No newline at end of file
diff --git a/examples/run_kimi_k2_eagle3_online.sh b/examples/run_kimi_k2_eagle3_online.sh
new file mode 100644
index 00000000..0d52a365
--- /dev/null
+++ b/examples/run_kimi_k2_eagle3_online.sh
@@ -0,0 +1,45 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# train eagle3 for llama3.1-8b
+NUM_GPUS=${1:-8}
+
+torchrun \
+ --standalone \
+ --nproc_per_node $NUM_GPUS \
+ $ROOT_DIR/scripts/train_eagle3_offline.py \
+ --target-model-path /root/models/Kimi-K2-Instruct \
+ --draft-model-config $ROOT_DIR/configs/kimi-k2-eagle3.json \
+ --train-data-path $ROOT_DIR/cache/dataset/test.jsonl \
+ --train-hidden-states-path $ROOT_DIR/cache/hidden_states/rows_0-5000 \
+ --output-dir $ROOT_DIR/outputs/Kimi-K2-eagle3 \
+ --num-epochs 10 \
+ --batch-size 1 \
+ --learning-rate 1e-4 \
+ --max-length 2048 \#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+# support tp6 train eagle3 for Kimi-K2
+NUM_GPUS=${1:-8}
+
+torchrun \
+ --standalone \
+ --nproc_per_node $NUM_GPUS \
+ $ROOT_DIR/scripts/train_eagle3_online.py \
+ --target-model-path /root/models/Kimi-K2-Instruct \
+ --draft-model-config $ROOT_DIR/configs/kimi-k2-eagle3.json \
+ --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+ --output-dir $ROOT_DIR/outputs/Kimi-K2-eagle3 \
+ --num-epochs 10 \
+ --batch-size 1 \
+ --learning-rate 1e-4 \
+ --max-length 2048 \
+ --chat-template kimi_k2 \
+ --cache-dir $ROOT_DIR/cache \
+ --embedding-key model.embed_tokens.weight \
+ --tp-size $NUM_GPUS
+
+ --chat-template kimi_k2 \
+ --cache-dir $ROOT_DIR/cache
\ No newline at end of file
diff --git a/scripts/convert_kimi_tokenizer.py b/scripts/convert_kimi_tokenizer.py
new file mode 100644
index 00000000..2affeb1c
--- /dev/null
+++ b/scripts/convert_kimi_tokenizer.py
@@ -0,0 +1,110 @@
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from transformers.integrations.tiktoken import convert_tiktoken_to_fast
+from transformers import AutoTokenizer
+
+def load_kimi_encoding(model_path, base_encoding_name="cl100k_base"):
+ """
+ Load Kimi encoding from tiktoken model file
+
+ Args:
+ model_path (str): Path to the tiktoken.model file
+ base_encoding_name (str): Name of base encoding to use as fallback
+
+ Returns:
+ tiktoken.Encoding: Kimi encoding object
+ """
+ try:
+ # Attempt to create encoding directly from file
+ with open(model_path, 'rb') as f:
+ data = f.read() # Read file for potential future use
+
+ # Load BPE ranks using tiktoken's internal function
+ mergeable_ranks = load_tiktoken_bpe(model_path)
+
+ except Exception as e:
+ # If failed, use base encoding + special tokens
+ print(f"Failed to load {model_path} directly: {str(e)}. Using base encoding {base_encoding_name}...")
+ base_encoding = tiktoken.get_encoding(base_encoding_name)
+ mergeable_ranks = base_encoding._mergeable_ranks
+
+ # Kimi's special tokens
+ special_tokens = {
+ "[BOS]": 163584,
+ "[EOS]": 163585,
+ "<|im_end|>": 163586,
+ "<|im_user|>": 163587,
+ "<|im_assistant|>": 163588,
+ "<|start_header_id|>": 163590,
+ "<|end_header_id|>": 163591,
+ "[EOT]": 163593,
+ "<|im_system|>": 163594,
+ "<|tool_calls_section_begin|>": 163595,
+ "<|tool_calls_section_end|>": 163596,
+ "<|tool_call_begin|>": 163597,
+ "<|tool_call_argument_begin|>": 163598,
+ "<|tool_call_end|>": 163599,
+ "<|im_middle|>": 163601,
+ "[UNK]": 163838,
+ "[PAD]": 163839
+ }
+
+ # Create tiktoken.Encoding object
+ return tiktoken.Encoding(
+ name="kimi_k2",
+ pat_str=r"""'(?:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?[\p{L}]+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+ mergeable_ranks=mergeable_ranks,
+ special_tokens=special_tokens
+ )
+
+def convert_to_fast_tokenizer(encoding, output_dir):
+ """
+ Convert tiktoken encoding to fast tokenizer and save
+
+ Args:
+ encoding (tiktoken.Encoding): Encoding object to convert
+ output_dir (str): Directory to save the fast tokenizer
+ """
+ # Convert to fast tokenizer
+ convert_tiktoken_to_fast(encoding, output_dir)
+ print(f"Conversion completed! Fast tokenizer saved to {output_dir}")
+
+ # Verify conversion result
+ fast_tokenizer = AutoTokenizer.from_pretrained(output_dir)
+ print(f"Verification: Fast tokenizer is {'valid' if fast_tokenizer.is_fast else 'invalid'}")
+ return fast_tokenizer
+
+def main(model_path, output_dir, base_encoding_name="cl100k_base"):
+ """
+ Main function to orchestrate the tokenizer conversion process
+
+ Args:
+ model_path (str): Path to the tiktoken.model file
+ output_dir (str): Directory to save the fast tokenizer
+ base_encoding_name (str): Name of base encoding to use as fallback
+ """
+ try:
+ # Load encoding using Method 1
+ encoding = load_kimi_encoding(model_path, base_encoding_name)
+
+ print(f"Successfully created encoding object: {encoding.name}")
+ print(f"Number of special tokens: {len(encoding._special_tokens)}")
+ print(f"Number of BPE ranks: {len(encoding._mergeable_ranks)}")
+
+ # Convert to fast tokenizer
+ return convert_to_fast_tokenizer(encoding, output_dir)
+
+ except Exception as e:
+ print(f"Conversion failed: {str(e)}")
+ import traceback
+ traceback.print_exc()
+ return None
+
+if __name__ == "__main__":
+ # Configuration parameters - can be modified or passed as command line arguments
+ MODEL_PATH = "moonshotai/Kimi-K2-Instruct/tiktoken.model"
+ OUTPUT_DIR = "moonshotai/Kimi-K2-Instruct"
+ BASE_ENCODING = "cl100k_base"
+
+ # Execute main function with parameters
+ main(MODEL_PATH, OUTPUT_DIR, BASE_ENCODING)
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
index 8ca9cc78..932ae064 100644
--- a/scripts/prepare_data.py
+++ b/scripts/prepare_data.py
@@ -170,6 +170,5 @@ def main():
if total_skipped_count > 0:
print(f"Skipped {total_skipped_count}/{len(ds)} messages for {args.dataset}")
-
if __name__ == "__main__":
main()
diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index 95dc923a..7cb8df60 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -33,7 +33,7 @@
set_gpu_proc_affinity,
)
from tqdm import tqdm
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerFast
from specforge.data import build_eagle3_dataset
from specforge.utils import print_with_rank, rank_0_priority
@@ -74,7 +74,9 @@ def __init__(self, args, tp_rank: int):
self.server_args = ServerArgs.from_cli_args(args)
self.server_args.enable_return_hidden_states = True
self.server_args.context_length = args.max_length
-
+
+ # 新增:添加trust_remote_code参数
+ self.server_args.trust_remote_code = True # 关键修改
self.server_args.cuda_graph_max_bs = max(self.bench_args.batch_size)
self.server_args.cuda_graph_bs = list(self.bench_args.batch_size)
_set_envs_and_config(self.server_args)
@@ -344,7 +346,7 @@ def main():
dataset = load_dataset("json", data_files=args.data_path)["train"]
if args.num_samples is not None:
dataset = dataset.select(range(args.num_samples))
- tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_path, trust_remote_code=True, use_fast=True)
cache_key = hashlib.md5(args.data_path.encode()).hexdigest()
with rank_0_priority():
eagle3_dataset = build_eagle3_dataset(
diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index 1dfd464d..95717325 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -10,7 +10,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizerFast
from specforge import AutoDraftModelConfig, AutoEagle3DraftModel, OfflineEagle3Model
from specforge.data import (
@@ -138,7 +138,7 @@ def main():
print_with_rank(f"Initialized draft model")
# build dataloaders
- tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(args.target_model_path)
# convert to dataloader
cache_key = hashlib.md5(args.train_data_path.encode()).hexdigest()
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index f90e4da8..4782f40f 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -29,7 +29,7 @@
import torch
from datasets import Dataset as HFDataset
from tqdm import tqdm
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from specforge.utils import padding
@@ -44,7 +44,7 @@
# ==============================
# Copied from https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py
def preprocess_conversations(
- tokenizer: PreTrainedTokenizer,
+ tokenizer: PreTrainedTokenizerFast,
conversations: List[Conversation],
chat_template: ChatTemplate,
max_length: int = 2048,
@@ -65,6 +65,8 @@ def preprocess_conversations(
- attention_mask: List of attention masks.
"""
system_prompt = chat_template.system_prompt
+ # This template is only suitable for other models.
+
user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
@@ -72,6 +74,14 @@ def preprocess_conversations(
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)
+ # For kimi_k2, use the modified conversation template.
+ # user_message_separator = (
+ # chat_template.user_header
+ # )
+ # assistant_message_separator = (
+ # chat_template.assistant_header
+ # )
+
# prepare result
results = {"input_ids": [], "loss_mask": [], "attention_mask": []}
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 16b0f5d7..94d241ab 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -113,3 +113,68 @@ def get_all_template_names(self) -> List[str]:
end_of_turn_token="<|im_end|>\n",
),
)
+
+
+TEMPLATE_REGISTRY.register(
+ name="qwen3",
+ template=ChatTemplate(
+ # 角色头部标识
+ system_header="<|im_start|>system\n",
+ user_header="<|im_start|>user\n",
+ assistant_header="<|im_start|>assistant\n",
+ system_prompt="You are a helpful assistant.",
+ tool_header="<|im_start|>tool\n",
+
+ # 工具相关标记
+ tools_declaration_prefix="# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n",
+ tools_declaration_suffix="\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n",
+
+ # 工具调用标记
+ tool_call_begin="\n",
+ tool_call_end="\n",
+ tool_response_wrapper="\n{content}\n",
+
+ # 思考过程标记
+ reasoning_wrapper="<|FunctionCallBegin|>\n{reasoning}\n<|FunctionCallEnd|>\n\n",
+
+ # 轮次结束标记
+ end_of_turn_token="<|im_end|>\n",
+
+ # 生成提示
+ generation_prompt="<|im_start|>assistant\n",
+ default_thinking_prompt="<|FunctionCallBegin|>\n\n\n\n"
+ ),
+)
+
+
+TEMPLATE_REGISTRY.register(
+ name="kimi_k2",
+ template=ChatTemplate(
+ # 系统提示相关配置
+ system_header="<|im_system|>system<|im_middle|>",
+ system_prompt="You are a helpful assistant.",
+
+ # 角色前缀配置
+ user_header="<|im_user|>user<|im_middle|>",
+ assistant_header="<|im_assistant|>assistant<|im_middle|>",
+ tool_header="<|im_system|>tool<|im_middle|>",
+
+ # 工具调用相关标记
+ tool_declare_prefix="<|im_system|>tool_declare<|im_middle|>",
+ tool_calls_section_begin="<|tool_calls_section_begin|>",
+ tool_calls_section_end="<|tool_calls_section_end|>",
+ tool_call_begin="<|tool_call_begin|>",
+ tool_call_argument_begin="<|tool_call_argument_begin|>",
+ tool_call_end="<|tool_call_end|>",
+
+ # 媒体内容标记
+ media_start="<|media_start|>image<|media_content|><|media_pad|>",
+ media_end="<|media_end|>",
+
+ # 轮次结束标记
+ end_of_turn_token="<|im_end|>\n",
+
+ # 生成提示前缀
+ generation_prompt="<|im_assistant|>assistant<|im_middle|>"
+ ),
+)
\ No newline at end of file
diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py
index 9124d61a..281e1ff1 100644
--- a/specforge/modeling/auto.py
+++ b/specforge/modeling/auto.py
@@ -11,6 +11,7 @@
LlamaConfig,
PretrainedConfig,
Qwen3MoeConfig,
+ DeepseekV3Config,
)
from specforge.utils import default_torch_dtype
@@ -18,6 +19,7 @@
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.llama4 import Llama4ForCausalLM
from .target.qwen3_moe import Qwen3MoeForCausalLM
+from .target.kimi_k2 import DeepseekV3ForCausalLM
class AutoEagle3DraftModel(AutoModelForCausalLMBase):
@@ -48,6 +50,7 @@ class AutoDistributedTargetModel(AutoModelForCausalLMBase):
_model_mapping = {
Llama4TextConfig: [Llama4ForCausalLM],
Qwen3MoeConfig: [Qwen3MoeForCausalLM],
+ DeepseekV3Config: [DeepseekV3ForCausalLM],
}
@classmethod
@@ -59,7 +62,7 @@ def from_pretrained(
**config_kwargs,
):
config = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, **config_kwargs
+ pretrained_model_name_or_path, trust_remote_code=True, **config_kwargs
)
if isinstance(config, Llama4Config):
diff --git a/specforge/modeling/target/kimi_k2.py b/specforge/modeling/target/kimi_k2.py
new file mode 100644
index 00000000..e4352657
--- /dev/null
+++ b/specforge/modeling/target/kimi_k2.py
@@ -0,0 +1,927 @@
+# coding=utf-8
+# Copyright 2025 Qwen Team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypedDict
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import DeepseekV3Config
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import (
+ create_causal_mask,
+ create_sliding_window_causal_mask,
+)
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
+
+
+from specforge.distributed import get_tp_group
+from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear
+
+from .base import DistributedTargetModel
+
+logger = logging.get_logger(__name__)
+
+class TransformersKwargs(TypedDict, total=False):
+ """
+ Keyword arguments to be passed to the loss function
+
+ Attributes:
+ num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
+ Number of items in the batch. It is recommended to pass it when
+ you are doing gradient accumulation.
+ output_hidden_states (`Optional[bool]`, *optional*):
+ Most of the models support outputing all hidden states computed during the forward pass.
+ output_attentions (`Optional[bool]`, *optional*):
+ Turn this on to return the intermediary attention scores.
+ output_router_logits (`Optional[bool]`, *optional*):
+ For MoE models, this allows returning the router logits to compute the loss.
+ cumulative_seqlens_q (`torch.LongTensor`, *optional*)
+ Gets cumulative sequence length for query state.
+ cumulative_seqlens_k (`torch.LongTensor`, *optional*)
+ Gets cumulative sequence length for key state.
+ max_length_q (`int`, *optional*):
+ Maximum sequence length for query state.
+ max_length_k (`int`, *optional*):
+ Maximum sequence length for key state.
+ """
+
+ num_items_in_batch: Optional["torch.Tensor"]
+ output_hidden_states: Optional[bool]
+ output_attentions: Optional[bool]
+ output_router_logits: Optional[bool]
+ cumulative_seqlens_q: Optional["torch.LongTensor"]
+ cumulative_seqlens_k: Optional["torch.LongTensor"]
+ max_length_q: Optional[int]
+ max_length_k: Optional[int]
+
+
+
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+
+def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ r"""
+ TODO let's just use the original freqcis computation to not have the view
+ transpose + reshape! This is not optimized!
+ Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ b, h, s, d = q.shape
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ b, h, s, d = k.shape
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+
+class DeepseekV3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.num_heads = config.num_attention_heads
+ self.rope_theta = config.rope_theta
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.kv_lora_rank = config.kv_lora_rank
+ self.v_head_dim = config.v_head_dim
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_head_dim = config.qk_head_dim
+
+ self.is_causal = True
+
+ # Add TP support
+ self.tp_group = get_tp_group()
+
+ if self.q_lora_rank is None:
+ self.q_proj = ColumnParallelLinear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+ else:
+ self.q_a_proj = ColumnParallelLinear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+ self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+ self.kv_a_proj_with_mqa = ColumnParallelLinear(
+ config.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=config.attention_bias,
+ )
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.v_head_dim,
+ config.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ self.scaling = self.qk_head_dim ** (-0.5)
+ if self.config.rope_scaling is not None:
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scaling = self.scaling * mscale * mscale
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+ if self.q_lora_rank is None:
+ q_states = self.q_proj(hidden_states)
+ else:
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ q_states = q_states.view(query_shape).transpose(1, 2)
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+ cos, sin = position_embeddings
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+ else:
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ # Add all_reduce for TP
+ dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
+ return attn_output, attn_weights
+
+
+class DeepseekV3MLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ # Add TP support
+ self.tp_group = get_tp_group()
+
+ self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ # Add all_reduce for TP
+ dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group)
+ return down_proj
+
+
+class DeepseekV3TopkRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
+class DeepseekV3MoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = DeepseekV3TopkRouter(config)
+ self.shared_experts = DeepseekV3MLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+ # in original deepseek, the output of the experts are gathered once we leave this module
+ # thus the moe module is itelsf an IsolatedParallel module
+ # and all expert are "local" meaning we shard but we don't gather
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DeepseekV3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = DeepseekV3MoE(config)
+ else:
+ self.mlp = DeepseekV3MLP(config)
+
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ output_attentions: Optional[bool] = False, # 新增:控制是否返回注意力权重
+ output_router_logits: Optional[bool] = False, # 新增:控制是否返回MoE路由logits
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, ]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions, # 传递参数:控制是否计算注意力权重
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # 全连接/MoE层:处理路由logits
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ mlp_output = self.mlp(hidden_states)
+
+ # 分离隐藏状态和路由logits(仅MoE层有效)
+ if isinstance(mlp_output, tuple):
+ hidden_states, router_logits = mlp_output # MoE层返回(hidden_states, router_logits)
+ else:
+ hidden_states = mlp_output # 普通MLP层仅返回hidden_states
+ router_logits = None # 非MoE层时路由logits为None
+
+ hidden_states = residual + hidden_states
+
+ # 构建动态返回值:根据参数包含额外信息
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,) # 附加注意力权重
+ if output_router_logits and router_logits is not None:
+ outputs += (router_logits,) # 附加路由logits(仅MoE层且参数开启时)
+
+ return outputs
+
+
+class DeepseekV3RotaryEmbedding(nn.Module):
+ def __init__(self, config: DeepseekV3Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+
+@auto_docstring
+class DeepseekV3PreTrainedModel(PreTrainedModel):
+ config_class = DeepseekV3Config # 显式指定配置类
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+
+ # 细化Flash Attention版本支持
+ _supports_flash_attn_3 = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ # 新增缓存相关功能支持
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = False # 根据模型特性调整(如MoE层可能不支持静态缓存)
+
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ # 调用父类初始化逻辑后,添加自定义细化初始化
+ super()._init_weights(module)
+
+ # 原有的路由层初始化保留
+ if isinstance(module, DeepseekV3TopkRouter):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ # 新增其他模块的精细化初始化
+ if isinstance(module, nn.Linear):
+ # 线性层权重用配置的初始化范围,偏置置0
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ # 嵌入层权重初始化,padding位置置0
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, DeepseekV3RMSNorm):
+ # RMSNorm层权重初始化为1(保持输入分布稳定)
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class DeepseekV3Model(DeepseekV3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
+
+ def __init__(self, config: DeepseekV3Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ # 参数默认值处理
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ output_router_logits = (
+ output_router_logits
+ if output_router_logits is not None
+ else self.config.output_router_logits
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # 初始化收集变量
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+
+ # 层循环处理
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ # 在每层处理之前保存隐藏状态
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # 调用层处理
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # 处理层输出
+ if isinstance(layer_outputs, tuple):
+ hidden_states = layer_outputs[0]
+
+ # 收集注意力权重
+ if output_attentions and len(layer_outputs) > 1:
+ all_self_attns += (layer_outputs[1],)
+
+ # 收集路由器输出
+ if output_router_logits and len(layer_outputs) > 2:
+ all_router_logits += (layer_outputs[-1],)
+ else:
+ # 如果层返回的不是元组,直接赋值
+ hidden_states = layer_outputs
+
+ # 应用最终的层归一化
+ hidden_states = self.norm(hidden_states)
+
+ # 添加最后一层的隐藏状态
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ print("+"*50)
+ print(f"all_hidden_states 包含的层数: {len(all_hidden_states)}")
+ # 返回 MoE 模型输出
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+@auto_docstring
+class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin, DistributedTargetModel):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = DeepseekV3Model(config)
+ self.vocab_size = config.vocab_size
+ # Use ColumnParallelLinear for lm_head
+ self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
+
+ >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_router_logits = (
+ output_router_logits
+ if output_router_logits is not None
+ else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ # Gather logits from all TP ranks
+ logits = self._gather_tensor(logits, get_tp_group())
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def load_weights(self, state_dict: Dict[str, torch.Tensor]):
+ """Load weights with TP sharding support"""
+ tp_group = get_tp_group()
+
+ updated_state_dict = {}
+ for key, value in state_dict.items():
+ # Ensure that the state dict is a flat dict of keys and tensors
+ if not isinstance(value, torch.Tensor):
+ raise ValueError(
+ f"Expected all values in the state dict to be torch.Tensor. "
+ f"Found {type(value)} instead."
+ )
+
+ module_key = ".".join(key.split(".")[:-1])
+ try:
+ module = self.get_submodule(module_key)
+ except AttributeError:
+ # Skip keys that don't correspond to existing modules
+ continue
+
+ # Handle expert weights specially
+ if "experts.gate_up_proj" in key:
+ gate, up = value.chunk(2, dim=-1)
+ # Shard the gate and up based on tp
+ gate = self._shard_tensor(gate, tp_group, -1)
+ up = self._shard_tensor(up, tp_group, -1)
+ value = torch.cat((gate, up), dim=-1)
+ elif "experts.down_proj" in key:
+ value = self._shard_tensor(value, tp_group, 1)
+ elif isinstance(module, RowParallelLinear) and key.endswith(".weight"):
+ value = self._shard_tensor(value, tp_group, -1)
+ elif isinstance(module, ColumnParallelLinear) and key.endswith(".weight"):
+ value = self._shard_tensor(value, tp_group, 0)
+ elif isinstance(module, ColumnParallelLinear) and key.endswith(".bias"):
+ value = self._shard_tensor(value, tp_group, 0)
+
+ updated_state_dict[key] = value
+
+ # Load state dict
+ self.load_state_dict(updated_state_dict, strict=False)
\ No newline at end of file
diff --git a/specforge/modeling/target/qwen3_moe.py b/specforge/modeling/target/qwen3_moe.py
index 075a0479..3bc7abb3 100644
--- a/specforge/modeling/target/qwen3_moe.py
+++ b/specforge/modeling/target/qwen3_moe.py
@@ -138,6 +138,13 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
+<<<<<<< HEAD
+ self.total_num_kv_heads = config.num_key_value_heads
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+=======
+>>>>>>> upstream/main
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True