24 commits
f60f458 Add mistral target model (ValeGian, Aug 31, 2025)
5b83e92 Add mistral to AutoDistributedTargetModel _model_mapping (ValeGian, Aug 31, 2025)
14582ab Register mistral templates (ValeGian, Aug 31, 2025)
37bf89e Add target model unit test (ValeGian, Aug 31, 2025)
1678ae5 Add head_dim to test MistralConfig (ValeGian, Aug 31, 2025)
7d6de07 Remove mistral v0.1 and v0.3 templates, add mistral small 24B template (ValeGian, Aug 31, 2025)
898e471 Add mistral-small-24B eagle3 config (ValeGian, Aug 31, 2025)
6b4ac96 Fix wrong chat_template.end_of_turn_token None check (ValeGian, Aug 31, 2025)
1eb91d7 Test mistral-small-24B preprocessing (ValeGian, Aug 31, 2025)
8922788 Restore Qwen3-8B preprocessing test (ValeGian, Aug 31, 2025)
037980f Add train script for mistral-Small-24B (ValeGian, Sep 1, 2025)
4d1aeb9 Merge branch 'main' into mistral (ValeGian, Sep 1, 2025)
877247b Lint fix (ValeGian, Sep 1, 2025)
a059679 Fix misleading return type (ValeGian, Sep 1, 2025)
06cdfeb Fix code format (ValeGian, Sep 4, 2025)
fb3cd1d Merge branch 'main' into mistral (ZhengHSI, Sep 15, 2025)
ab36686 Fix preprocessing (ValeGian, Sep 22, 2025)
5ca235e Merge branch 'main' into mistral (ValeGian, Sep 22, 2025)
e8a43a2 Fix linting (ValeGian, Sep 22, 2025)
26022f1 Increase TP to 2 to fit on H100 with 96GB (ValeGian, Sep 22, 2025)
0744426 Merge branch 'main' into mistral (ZhengHSI, Oct 1, 2025)
33c4bd7 Set default NUM_GPUS to 2 (ValeGian, Oct 1, 2025)
5d1fd1b Merge branch 'main' into mistral (ZhengHSI, Oct 2, 2025)
ebf4e6b Merge branch 'main' into mistral (ZhengHSI, Oct 7, 2025)
27 changes: 27 additions & 0 deletions configs/mistral-small-24B-eagle3.json
@@ -0,0 +1,27 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 32768,
"max_position_embeddings": 32768,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 100000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.47.0",
"use_cache": true,
"vocab_size": 131072,
"draft_vocab_size": 32000
}
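
A quick sanity check on this draft config (a hedged sketch, assuming transformers is installed and the path is relative to the repo root): since "model_type" is "llama", the stock LlamaConfig parser accepts the file, and SpecForge-specific extras such as draft_vocab_size are carried along as plain attributes. Note the draft is a single Llama-style decoder layer (num_hidden_layers: 1) sized to match the 24B target's 5120-wide hidden states.

from transformers import LlamaConfig

# Hedged sketch: load the Eagle3 draft config directly from the JSON file.
# Keys unknown to LlamaConfig (e.g. draft_vocab_size) become plain attributes.
config = LlamaConfig.from_json_file("configs/mistral-small-24B-eagle3.json")
print(config.num_hidden_layers, config.hidden_size, config.draft_vocab_size)
# -> 1 5120 32000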
23 changes: 23 additions & 0 deletions examples/run_mistral_small_24B_eagle3_online.sh
@@ -0,0 +1,23 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# train eagle3 for mistral-Small-24B
NUM_GPUS=${1:-2}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_online.py \
--target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \
--draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
--output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
--num-epochs 2 \
--batch-size 1 \
--tp 2 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template mistral-small-24B \
--cache-dir $ROOT_DIR/cache \
--attention-backend flex_attention
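
Usage note: the script accepts one optional positional argument for the GPU count (defaulting to 2), e.g. bash examples/run_mistral_small_24B_eagle3_online.sh 4. The tensor-parallel degree is pinned at --tp 2, which per the commit history is what lets the 24B target fit on 96GB H100s.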
12 changes: 6 additions & 6 deletions specforge/data/parse.py
@@ -41,12 +41,12 @@ class GeneralParser(Parser):
     def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
         super().__init__(tokenizer, chat_template)
         self.system_prompt = chat_template.system_prompt
-        self.user_message_separator = (
-            f"{chat_template.end_of_turn_token}{chat_template.user_header}"
-        )
-        self.assistant_message_separator = (
-            f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
-        )
+        if chat_template.end_of_turn_token:
+            self.user_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
+            self.assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
+        else:
+            self.user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+            self.assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
 
     def parse(
         self,
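
To see what the new fallback branch produces for the template registered below (a standalone sketch mirroring the logic above, not the SpecForge API): mistral-small-24B sets only end_of_assistant_token, so the user separator becomes "</s>[INST]" and the assistant separator is just "[/INST]".

# Standalone sketch; values taken from the mistral-small-24B template below.
end_of_turn_token = None
end_of_assistant_token = "</s>"
end_of_user_token = None
user_header = "[INST]"
assistant_header = "[/INST]"

if end_of_turn_token:
    user_sep = f"{end_of_turn_token}{user_header or ''}"
    assistant_sep = f"{end_of_turn_token}{assistant_header or ''}"
else:
    # Missing tokens fall back to the empty string via "or ''".
    user_sep = f"{end_of_assistant_token or ''}{user_header or ''}"
    assistant_sep = f"{end_of_user_token or ''}{assistant_header or ''}"

print(repr(user_sep))       # '</s>[INST]'
print(repr(assistant_sep))  # '[/INST]'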
14 changes: 8 additions & 6 deletions specforge/data/preprocessing.py
@@ -71,12 +71,14 @@ def _apply_loss_mask_from_chat_template(
     """
     loss_mask = torch.zeros(len(offsets), dtype=torch.long)
 
-    user_message_separator = (
-        f"{chat_template.end_of_turn_token}{chat_template.user_header}"
-    )
-    assistant_message_separator = (
-        f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
-    )
+    if chat_template.end_of_turn_token:
+        user_message_separator = (
+            f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
+        )
+        assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
+    else:
+        user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+        assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
 
     # Find spans of assistant responses using regex
     assistant_pattern = (
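
These separators feed the regex that locates assistant spans. A minimal illustration of the idea (a toy rendering and pattern, not the actual SpecForge implementation): tokens whose character offsets fall inside an assistant response get loss_mask 1, everything else stays 0.

import re
import torch

text = "[INST] Hi [/INST] Hello there! </s>[INST] Bye [/INST] Goodbye! </s>"
# Toy "token" offsets: one (start, end) pair per whitespace-separated word.
offsets = [m.span() for m in re.finditer(r"\S+", text)]

assistant_sep = re.escape("[/INST]")    # from the mistral-small-24B template
user_sep = re.escape("</s>[INST]")
# An assistant span runs from its separator to the next user separator (or EOS).
spans = [
    m.span(1)
    for m in re.finditer(f"{assistant_sep}(.*?)(?={user_sep}|$)", text, re.DOTALL)
]

loss_mask = torch.zeros(len(offsets), dtype=torch.long)
for i, (start, end) in enumerate(offsets):
    if any(s <= start and end <= e for s, e in spans):
        loss_mask[i] = 1
print(loss_mask)  # 1s over "Hello there!" and "Goodbye! </s>"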
24 changes: 23 additions & 1 deletion specforge/data/template.py
@@ -13,12 +13,17 @@ class ChatTemplate(BaseModel):
         user_header(str): The header for the user.
         system_prompt(str): The system prompt.
         end_of_turn_token(str): The end token of a turn of conversation.
+            If present, end_of_assistant_token and end_of_user_token are ignored.
+        end_of_assistant_token(str): The end token of an assistant turn of conversation.
+        end_of_user_token(str): The end token of a user turn of conversation.
     """
 
     assistant_header: str | None
     user_header: str | None
     system_prompt: str | None
-    end_of_turn_token: str | None
+    end_of_turn_token: str | None = None
+    end_of_assistant_token: str | None = None
+    end_of_user_token: str | None = None
     parser_type: str = "general"
 
@@ -105,6 +110,23 @@ def get_all_template_names(self) -> List[str]:
     ),
 )
 
+TEMPLATE_REGISTRY.register(
+    name="mistral-small-24B",
+    template=ChatTemplate(
+        assistant_header="[/INST]",
+        user_header="[INST]",
+        system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup "
+        "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date "
+        "is 2025-08-31. When you're not sure about some information, you say that you don't have the "
+        "information and don't make up anything. If the user's question is not clear, ambiguous, or "
+        "does not provide enough context for you to accurately answer the question, you do not try to "
+        'answer it right away and you rather ask the user to clarify their request (e.g. "What are '
+        'some good restaurants around me?" => "Where are you?" or "When is the next flight to '
+        'Tokyo" => "Where do you travel from?")',
+        end_of_assistant_token="</s>",
+    ),
+)
+
 TEMPLATE_REGISTRY.register(
     name="qwen",
     template=ChatTemplate(
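
Because the two new fields and end_of_turn_token all default to None, every existing registration keeps validating unchanged; only templates without a single end-of-turn token (like this one) need the per-role end tokens. A hedged sketch of that behavior, assuming pydantic as implied by BaseModel (field values here are illustrative, not the actual registrations):

from pydantic import BaseModel

class ChatTemplate(BaseModel):  # trimmed copy of the fields above
    assistant_header: str | None
    user_header: str | None
    system_prompt: str | None
    end_of_turn_token: str | None = None
    end_of_assistant_token: str | None = None
    end_of_user_token: str | None = None
    parser_type: str = "general"

# Old-style template: still valid, the new fields simply default to None.
old_style = ChatTemplate(
    assistant_header="assistant:",
    user_header="user:",
    system_prompt="You are a helpful assistant.",
    end_of_turn_token="<eot>",
)
# New-style template: no end_of_turn_token, per-role end tokens instead.
mistral_like = ChatTemplate(
    assistant_header="[/INST]",
    user_header="[INST]",
    system_prompt="...",
    end_of_assistant_token="</s>",
)
assert old_style.end_of_assistant_token is None
assert mistral_like.end_of_turn_token is None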
3 changes: 3 additions & 0 deletions specforge/modeling/auto.py
@@ -11,6 +11,7 @@
     Llama4Config,
     Llama4TextConfig,
     LlamaConfig,
+    MistralConfig,
     Phi3Config,
     PretrainedConfig,
     Qwen2_5_VLConfig,
@@ -26,6 +27,7 @@
 from .target.gpt_oss import GptOssForCausalLM
 from .target.llama import LlamaForCausalLM
 from .target.llama4 import Llama4ForCausalLM
+from .target.mistral import MistralForCausalLM
 from .target.phi3 import Phi3ForCausalLM
 from .target.qwen2 import Qwen2ForCausalLM
 from .target.qwen3 import Qwen3ForCausalLM
@@ -94,6 +96,7 @@ class AutoDistributedTargetModel(AutoModelForCausalLMBase):
         LlamaConfig: [LlamaForCausalLM],
         Qwen3Config: [Qwen3ForCausalLM],
         Phi3Config: [Phi3ForCausalLM],
+        MistralConfig: [MistralForCausalLM],
         GptOssConfig: [GptOssForCausalLM],
     }
 
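
The Auto-class dispatch itself is just a config-type lookup. A simplified stand-in (not the SpecForge implementation; the real _model_mapping holds the imported model classes shown above):

from transformers import MistralConfig

# Simplified stand-in for AutoDistributedTargetModel._model_mapping:
# the config class of the checkpoint selects the distributed target model.
_model_mapping = {MistralConfig: ["MistralForCausalLM"]}

config = MistralConfig()  # what AutoConfig would yield for Mistral-Small-24B-Instruct-2501
print(_model_mapping[type(config)][0])  # -> MistralForCausalLM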