13 changes: 12 additions & 1 deletion README.md
@@ -164,15 +164,26 @@ bash ./examples/run_qwen3_moe_eagle3_online.sh

# train Qwen3-8B
bash ./examples/run_qwen3_dense_eagle3_online.sh

# train Kimi-K2
bash ./examples/run_kimi_k2_eagle3_online.sh
```

### 💨 Offline Training

We have provided a simple startup script to train the Eagle3 model for Llama-3.1-8B-Instruct model in an offline manner. You can run the following command to start the training. Almost Everything is the same as the Online Training Step, except that you don't need to configure anything about target model. Instead, you need to pass `--train-hidden-states-path` to the file.
We have provided simple startup scripts to train the Eagle3 model for the Llama-3.1-8B-Instruct and Kimi-K2-Instruct models in an offline manner. You can run the following commands to start training. Almost everything is the same as in the online training step, except that you do not need to configure anything about the target model; instead, you pass `--train-hidden-states-path` pointing to the pre-computed hidden states.

Note: the tokenizer loaded by default for the Kimi-K2-Instruct model is a TikToken-based tokenizer. It is not a Rust-backed fast tokenizer and lacks the interfaces required during data processing. First run the script below to generate `tokenizer.json` in the root directory of the Kimi-K2-Instruct model, then replace `AutoTokenizer` with `PreTrainedTokenizerFast` in `scripts/prepare_hidden_states.py` (see the sketch after the command). After that, the framework can be used as usual.
```bash
python scripts/convert_kimi_tokenizer.py
```
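For reference, the change in `scripts/prepare_hidden_states.py` amounts to swapping the tokenizer class (a minimal sketch of the relevant line; the surrounding code stays the same):

```python
from transformers import PreTrainedTokenizerFast

# before: tokenizer = AutoTokenizer.from_pretrained(args.model_path)
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_path, trust_remote_code=True)
```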


```bash
# make sure you have sharegpt data prepared
bash ./examples/run_llama3_eagle3_offline.sh

bash ./examples/run_kimi_k2_eagle3_offline.sh
```

### 📈 Wandb Integration
39 changes: 39 additions & 0 deletions configs/kimi-k2-eagle3.json
@@ -0,0 +1,39 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [
2,
30,
58
],
"use_aux_hidden_state": true
},
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 163584,
"eos_token_id": 163585,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"max_position_embeddings": 131072,
"max_window_layers": 48,
"model_type": "kimi_k2",
"num_attention_heads": 64,
"num_hidden_layers": 1,
"num_key_value_heads":64,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.53.2",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 163840,
"draft_vocab_size": 32000
}
11 changes: 11 additions & 0 deletions examples/get_hidden_states.sh
@@ -0,0 +1,11 @@
torchrun --nproc_per_node=8 \
scripts/prepare_hidden_states.py \
--model-path /root/models/Kimi-K2-Instruct \
--enable-aux-hidden-states \
--data-path /root/script/SpecForge/cache/dataset/test.jsonl \
--chat-template kimi_k2 \
--max-length 2048 \
--tp-size 8 \
--batch-size 1 \
--mem-frac=0.95 \
--num-samples 2000
21 changes: 21 additions & 0 deletions examples/run_kimi_k2_eagle3_offline.sh
@@ -0,0 +1,21 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

# train eagle3 offline for Kimi-K2-Instruct
NUM_GPUS=${1:-8}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_offline.py \
--target-model-path /root/models/Kimi-K2-Instruct \
--draft-model-config $ROOT_DIR/configs/kimi-k2-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/test.jsonl \
--train-hidden-states-path $ROOT_DIR/cache/hidden_states/rows_0-5000 \
--output-dir $ROOT_DIR/outputs/Kimi-K2-eagle3 \
--num-epochs 10 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template kimi_k2 \
--cache-dir $ROOT_DIR/cache
24 changes: 24 additions & 0 deletions examples/run_kimi_k2_eagle3_online.sh
@@ -0,0 +1,24 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)

# train eagle3 online for Kimi-K2 with tensor parallelism
NUM_GPUS=${1:-8}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_online.py \
--target-model-path /root/models/Kimi-K2-Instruct \
--draft-model-config $ROOT_DIR/configs/kimi-k2-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
--output-dir $ROOT_DIR/outputs/Kimi-K2-eagle3 \
--num-epochs 10 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template kimi_k2 \
--cache-dir $ROOT_DIR/cache \
--embedding-key model.embed_tokens.weight \
--tp-size $NUM_GPUS
110 changes: 110 additions & 0 deletions scripts/convert_kimi_tokenizer.py
@@ -0,0 +1,110 @@
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from transformers.integrations.tiktoken import convert_tiktoken_to_fast
from transformers import AutoTokenizer

def load_kimi_encoding(model_path, base_encoding_name="cl100k_base"):
"""
Load Kimi encoding from tiktoken model file

Args:
model_path (str): Path to the tiktoken.model file
base_encoding_name (str): Name of base encoding to use as fallback

Returns:
tiktoken.Encoding: Kimi encoding object
"""
try:
# Attempt to create encoding directly from file
with open(model_path, 'rb') as f:
data = f.read() # Read file for potential future use

# Load BPE ranks using tiktoken's internal function
mergeable_ranks = load_tiktoken_bpe(model_path)

except Exception as e:
# If failed, use base encoding + special tokens
print(f"Failed to load {model_path} directly: {str(e)}. Using base encoding {base_encoding_name}...")
base_encoding = tiktoken.get_encoding(base_encoding_name)
mergeable_ranks = base_encoding._mergeable_ranks

# Kimi's special tokens
special_tokens = {
"[BOS]": 163584,
"[EOS]": 163585,
"<|im_end|>": 163586,
"<|im_user|>": 163587,
"<|im_assistant|>": 163588,
"<|start_header_id|>": 163590,
"<|end_header_id|>": 163591,
"[EOT]": 163593,
"<|im_system|>": 163594,
"<|tool_calls_section_begin|>": 163595,
"<|tool_calls_section_end|>": 163596,
"<|tool_call_begin|>": 163597,
"<|tool_call_argument_begin|>": 163598,
"<|tool_call_end|>": 163599,
"<|im_middle|>": 163601,
"[UNK]": 163838,
"[PAD]": 163839
}

# Create tiktoken.Encoding object
return tiktoken.Encoding(
name="kimi_k2",
pat_str=r"""'(?:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?[\p{L}]+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
mergeable_ranks=mergeable_ranks,
special_tokens=special_tokens
)

def convert_to_fast_tokenizer(encoding, output_dir):
"""
Convert tiktoken encoding to fast tokenizer and save

Args:
encoding (tiktoken.Encoding): Encoding object to convert
output_dir (str): Directory to save the fast tokenizer
"""
# Convert to fast tokenizer
convert_tiktoken_to_fast(encoding, output_dir)
print(f"Conversion completed! Fast tokenizer saved to {output_dir}")

# Verify conversion result
fast_tokenizer = AutoTokenizer.from_pretrained(output_dir)
print(f"Verification: Fast tokenizer is {'valid' if fast_tokenizer.is_fast else 'invalid'}")
return fast_tokenizer

def main(model_path, output_dir, base_encoding_name="cl100k_base"):
"""
Main function to orchestrate the tokenizer conversion process

Args:
model_path (str): Path to the tiktoken.model file
output_dir (str): Directory to save the fast tokenizer
base_encoding_name (str): Name of base encoding to use as fallback
"""
try:
# Load encoding using Method 1
encoding = load_kimi_encoding(model_path, base_encoding_name)

print(f"Successfully created encoding object: {encoding.name}")
print(f"Number of special tokens: {len(encoding._special_tokens)}")
print(f"Number of BPE ranks: {len(encoding._mergeable_ranks)}")

# Convert to fast tokenizer
return convert_to_fast_tokenizer(encoding, output_dir)

except Exception as e:
print(f"Conversion failed: {str(e)}")
import traceback
traceback.print_exc()
return None

if __name__ == "__main__":
# Configuration parameters - edit these paths for your environment
MODEL_PATH = "moonshotai/Kimi-K2-Instruct/tiktoken.model"
OUTPUT_DIR = "moonshotai/Kimi-K2-Instruct"
BASE_ENCODING = "cl100k_base"

# Execute main function with parameters
main(MODEL_PATH, OUTPUT_DIR, BASE_ENCODING)
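After conversion, a quick sanity check (a sketch, assuming the output directory above now contains the generated `tokenizer.json`):

```python
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("moonshotai/Kimi-K2-Instruct")
print(tokenizer.is_fast)                       # True once tokenizer.json has been generated
print(tokenizer.encode("Hello from Kimi-K2"))  # encodes through the Rust-backed fast tokenizer
```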
1 change: 0 additions & 1 deletion scripts/prepare_data.py
@@ -170,6 +170,5 @@ def main():
if total_skipped_count > 0:
print(f"Skipped {total_skipped_count}/{len(ds)} messages for {args.dataset}")


if __name__ == "__main__":
main()
8 changes: 5 additions & 3 deletions scripts/prepare_hidden_states.py
@@ -33,7 +33,7 @@
set_gpu_proc_affinity,
)
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerFast

from specforge.data import build_eagle3_dataset
from specforge.utils import print_with_rank, rank_0_priority
@@ -74,7 +74,9 @@ def __init__(self, args, tp_rank: int):
self.server_args = ServerArgs.from_cli_args(args)
self.server_args.enable_return_hidden_states = True
self.server_args.context_length = args.max_length


# New: add trust_remote_code so the custom Kimi-K2 model/tokenizer code can be loaded
self.server_args.trust_remote_code = True  # key change
self.server_args.cuda_graph_max_bs = max(self.bench_args.batch_size)
self.server_args.cuda_graph_bs = list(self.bench_args.batch_size)
_set_envs_and_config(self.server_args)
@@ -344,7 +346,7 @@ def main():
dataset = load_dataset("json", data_files=args.data_path)["train"]
if args.num_samples is not None:
dataset = dataset.select(range(args.num_samples))
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_path, trust_remote_code=True, use_fast=True)
cache_key = hashlib.md5(args.data_path.encode()).hexdigest()
with rank_0_priority():
eagle3_dataset = build_eagle3_dataset(
4 changes: 2 additions & 2 deletions scripts/train_eagle3_offline.py
@@ -10,7 +10,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from specforge import AutoDraftModelConfig, AutoEagle3DraftModel, OfflineEagle3Model
from specforge.data import (
@@ -138,7 +138,7 @@ def main():
print_with_rank(f"Initialized draft model")

# build dataloaders
tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.target_model_path)

# convert to dataloader
cache_key = hashlib.md5(args.train_data_path.encode()).hexdigest()
14 changes: 12 additions & 2 deletions specforge/data/preprocessing.py
@@ -29,7 +29,7 @@
import torch
from datasets import Dataset as HFDataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from specforge.utils import padding

@@ -44,7 +44,7 @@
# ==============================
# Copied from https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py
def preprocess_conversations(
tokenizer: PreTrainedTokenizer,
tokenizer: PreTrainedTokenizerFast,
conversations: List[Conversation],
chat_template: ChatTemplate,
max_length: int = 2048,
@@ -65,13 +65,23 @@ def preprocess_conversations(
- attention_mask: List of attention masks.
"""
system_prompt = chat_template.system_prompt
# These default separators apply to models other than kimi_k2 (see the kimi_k2 variant below).

user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)

# For kimi_k2, use the modified conversation template.
# user_message_separator = (
# chat_template.user_header
# )
# assistant_message_separator = (
# chat_template.assistant_header
# )

# prepare result
results = {"input_ids": [], "loss_mask": [], "attention_mask": []}
