
Conversation

@ValeGian (Contributor) commented Sep 1, 2025

Motivation

This PR adds support for training Mistral models.

Modifications

  • Added a distributed (tensor-parallel) model implementation for the Mistral architecture (a rough sketch of the layout follows this list)
  • Added the Mistral chat template to the template registry
  • Added training scripts
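
For reviewers, here is a minimal, illustrative sketch of the tensor-parallel layout used for the MLP-style blocks. It is not copied from the new mistral.py; the class name is made up and only standard torch / torch.distributed calls are used:

import torch
import torch.nn as nn
import torch.distributed as dist


class TensorParallelMLPSketch(nn.Module):
    """Gate/up projections are column-sharded, the down projection is
    row-sharded, and the partial outputs are summed with an all-reduce."""

    def __init__(self, hidden_size: int, intermediate_size: int, tp_size: int):
        super().__init__()
        assert intermediate_size % tp_size == 0
        shard = intermediate_size // tp_size
        self.gate_proj = nn.Linear(hidden_size, shard, bias=False)  # column-parallel slice
        self.up_proj = nn.Linear(hidden_size, shard, bias=False)    # column-parallel slice
        self.down_proj = nn.Linear(shard, hidden_size, bias=False)  # row-parallel slice
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # sum partial results across TP ranks
        return partial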

Accuracy Test

python tests/test_target_modeling/test_mistral_tp.py
test_mistral_tp (__main__.TestMistralTP) ... rank 1: bind to device 1
rank 0: bind to device 0
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank1]:[W901 17:53:44.564655749 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Saved model to /tmp/tmpp_pm18t_
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank0]:[W901 17:53:44.619454270 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading model from /tmp/tmpp_pm18t_
Loading model from /tmp/tmpp_pm18t_
[rank0]:[W901 17:53:47.376029869 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W901 17:53:48.134036842 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ok

----------------------------------------------------------------------
Ran 1 test in 11.484s

OK
python tests/test_preprocessing.py 
test_assistant_span_boundaries (__main__.TestPreprocessing)
Test that assistant span boundaries are correctly identified without truncation. ... ok
test_conversation_preprocessing_basic (__main__.TestPreprocessing)
Test basic conversation preprocessing with assistant response identification. ... ok
test_multiple_turns_conversation (__main__.TestPreprocessing)
Test conversation with multiple user-assistant turns. ... ok
test_preformatted_conversation (__main__.TestPreprocessing)
Test preprocessing of pre-formatted conversation strings. ... ok

----------------------------------------------------------------------
Ran 4 tests in 3.120s

OK

mistralai/Mistral-Small-24B-Instruct-2501 training
[training curve screenshots]

Checklist

@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @ValeGian, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Mistral model support into the framework, enabling users to train and utilize Mistral models with distributed capabilities. It includes the necessary model architecture definitions, configuration files, and training scripts, alongside improvements to chat template processing for broader compatibility.

Highlights

  • Mistral Model Support: Added comprehensive support for Mistral models, including their distributed implementation and configuration.
  • Distributed Model Implementation: Introduced distributed (Tensor Parallelism) implementations for Mistral model components, such as MLP, Attention, and the overall CausalLM.
  • Chat Template Enhancements: Improved flexibility in chat template handling by allowing optional end-of-turn tokens and introducing specific end-of-assistant/user tokens (see the illustrative sketch after this list).
  • New Configuration and Training Scripts: Included a new configuration file and a training script specifically tailored for the mistral-small-24B-eagle3 model.
  • Unit Testing for Tensor Parallelism: Added a dedicated unit test to validate the correctness of the tensor parallelism implementation for Mistral models.
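
To make the chat-template highlight concrete, here is a purely hypothetical sketch of a template entry with an optional end-of-turn token and separate end-of-assistant/user markers. The field names, registry dict, and example values are assumptions for illustration, not SpecForge's actual template API:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatTemplateSketch:
    user_header: str
    assistant_header: str
    system_prompt: str = ""
    end_of_turn_token: Optional[str] = None       # generic terminator, now optional
    end_of_user_token: Optional[str] = None       # specific end-of-user marker
    end_of_assistant_token: Optional[str] = None  # specific end-of-assistant marker


# Illustrative registration only; the real template name used in this PR is "mistral-small-24B".
TEMPLATE_REGISTRY_SKETCH = {
    "mistral-small-24B": ChatTemplateSketch(
        user_header="[INST]",
        assistant_header="[/INST]",
        end_of_assistant_token="</s>",  # assumed EOS-style terminator
    ),
}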

@gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for Mistral models, including the distributed model implementation, a new chat template, and training scripts. The changes are well structured and include a new test for tensor parallelism, which is great. I've found a couple of minor issues with incorrect type hints in the new mistral.py model file, which I've commented on. Correcting these will improve code clarity and maintainability. Overall, this is a solid contribution.

Comment on lines +202 to +204
) -> tuple[
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:

Severity: medium

The return type hint for this function is incorrect. The function returns outputs, which is either (hidden_states,) or (hidden_states, self_attn_weights). This corresponds to Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]. The current annotation is misleading as it suggests a different structure, similar to what might be returned if past_key_value were part of the output, which it is not.

    ) -> Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]:
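
For context, the control flow the comment describes can be sketched as follows (simplified and not the actual decoder-layer code; only the return shape is the point here):

from typing import Optional, Union

import torch


def forward_sketch(
    hidden_states: torch.FloatTensor,
    output_attentions: bool = False,
) -> Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]:
    # attention / MLP computation elided; attention would populate this when requested
    self_attn_weights: Optional[torch.Tensor] = None
    outputs: tuple = (hidden_states,)
    if output_attentions:
        outputs = outputs + (self_attn_weights,)
    return outputs  # (hidden_states,) or (hidden_states, self_attn_weights)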

@sleepcoo (Collaborator) commented Sep 4, 2025

Could you fix the code format? @ValeGian

@ValeGian (Contributor, Author) commented Sep 4, 2025

Fix code format

Done with 06cdfeb

sleepcoo requested a review from ZhengHSI on September 8, 2025, 11:17
@ZhengHSI (Collaborator) commented:

May I ask if you ran the training on the device mentioned above?
When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian

@ValeGian (Contributor, Author) commented:

May I ask if you ran the training on the device mentioned above? When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian

The model itself is around 47 GB on disk. I ran the training on a node with 8×H200 GPUs (an AWS p5en.48xlarge instance).

I just tried on a smaller node and was able to run a test training on 2×H100 GPUs by modifying examples/run_mistral_small_24B_eagle3_online.sh to use --tp 2:

bash examples/run_mistral_small_24B_eagle3_online.sh
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
rank 0: bind to device 0
rank 0: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 0: Initialized distributed environment
rank 0: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
rank 1: bind to device 1
rank 1: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 1: Initialized distributed environment
rank 1: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 48210.39it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 26259.16it/s]
rank 1: Initialized target model
rank 0: Initialized target model
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 34239.22it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 18139.31it/s]
rank 0: Initialized draft model
rank 1: Initialized draft model
dataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank1]:[W912 10:54:26.648956214 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank0]:[W912 10:54:26.764850022 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
rank 0: Initialized train dataloader
rank 0: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 0: Loaded vocab mappingdataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl

/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
  warnings.warn(
rank 0: Initialized Eagle3 FSDP model
rank 0: Initialized optimizer and scheduler
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
rank 1: Initialized train dataloader
rank 1: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 1: Loaded vocab mapping
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
  warnings.warn(
rank 1: Initialized Eagle3 FSDP model
rank 1: Initialized optimizer and scheduler
Starting training from epoch 0
Training Epoch 0:   0%|                                                                                                                                    | 84/120675 [00:22<6:29:14,  5.16it/s, loss=0.00, acc=0.00]

with the following nvidia-smi output:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:07:00.0 Off |                    0 |
| N/A   51C    P0            587W /  700W |   67691MiB /  81559MiB |     99%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:08:00.0 Off |                    0 |
| N/A   63C    P0            592W /  700W |   67691MiB /  81559MiB |     93%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A           22069      C   ...u/SpecForge/.venv/bin/python3      67682MiB |
|    1   N/A  N/A           22070      C   ...u/SpecForge/.venv/bin/python3      67682MiB |
+-----------------------------------------------------------------------------------------+

Do you want me to update the examples/run_mistral_small_24B_eagle3_online.sh so that tp is 2?

@ValeGian (Contributor, Author) commented Sep 12, 2025

@ZhengHSI note that the training for which I reported the curves was tuned for the 8×H200 node; the complete set of parameters logged to MLflow was:

Name | Value
attention_backend | flex_attention
batch_size | 1
build_dataset_num_proc | 192
cache_dir | /home/ubuntu/SpecForge/cache
cache_key | None
chat_template | mistral-small-24B
dist_timeout | 20
dp_size | 2
draft_accumulation_steps | 4
draft_global_batch_size | 8
draft_micro_batch_size | 1
draft_model_config | /home/ubuntu/SpecForge/configs/mistral-small-24B-eagle3.json
embedding_key | model.embed_tokens.weight
eval_data_path | None
eval_interval | 1
is_preformatted | False
is_vlm | False
learning_rate | 0.0001
log_steps | 50
max_grad_norm | 0.5
max_length | 2048
max_pixels | 802816
min_pixels | 50176
mlflow_experiment_name | EAGLE3-mistral-Small-24B
mlflow_run_name | None
mlflow_tracking_uri | <MLflow URI>
num_epochs | 2
output_dir | /home/ubuntu/SpecForge/outputs/mistral-Small-24B-eagle3
profile | False
profile_num_steps | 4
profile_record_shapes | False
profile_start_step | 30
report_to | mlflow
resume | False
save_interval | 1
seed | 0
swanlab_key | None
swanlab_name | None
swanlab_project | None
target_model_path | mistralai/Mistral-Small-24B-Instruct-2501
total_steps | None
tp_size | 4
train_data_path | /home/ubuntu/SpecForge/cache/dataset/sharegpt.jsonl
ttt_length | 7
verbose | False
wandb_key | None
wandb_name | None
wandb_project | None
warmup_ratio | 0.015
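
As a sanity check, the batch parameters above are consistent with the accumulation-step log line from the 2×H100 run, assuming the relation implied by that line (accumulation = global batch // dp size // micro batch; the formula is inferred from the log, not quoted from the code):

def draft_accumulation_steps(global_batch: int, dp_size: int, micro_batch: int) -> int:
    # inferred from the "draft_accumulation_steps=8 // 1 // 1=8" log line above
    return global_batch // dp_size // micro_batch


assert draft_accumulation_steps(8, 2, 1) == 4  # 8xH200 run (table above): dp_size=2, micro_batch=1
assert draft_accumulation_steps(8, 1, 1) == 8  # 2xH100 run (log above): dp_size=1, micro_batch=1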

I didn't upload an updated configuration because the examples folder keeps pretty much the same configuration for every training script, even for larger models such as meta-llama/Llama-4-Scout-17B-16E:

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path meta-llama/Llama-4-Scout-17B-16E \
    --draft-model-config $ROOT_DIR/configs/llama4-scout-17B-16E-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
    --output-dir $ROOT_DIR/outputs/llama4-scout-17B-16E-eagle3 \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama4 \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key language_model.model.embed_tokens.weight \
    --tp-size $NUM_GPUS

@ZhengHSI (Collaborator) commented:

Thanks for your answer. It would be better to update the script: your current script does not set the TP size, so tensor parallelism is not enabled, which leads to the OOM. Please modify the script accordingly. @ValeGian

@ZhengHSI (Collaborator) commented:

[screenshot] In addition, I tried training several times, but the loss and accuracy have always remained at 0. I saw in your previous answer that you also encountered this situation.

@ValeGian (Contributor, Author) commented:

[screenshot] In addition, I tried training several times, but the loss and accuracy have always remained at 0. I saw in your previous answer that you also encountered this situation.

I’ll look into this. It’s odd, since I previously completed a few training runs successfully and shared the MLflow run for one of them. I'll try once again, since a few merges from main plus some minor commits have landed since then.

@ValeGian (Contributor, Author) commented:

@ZhengHSI I confirmed that recent merges from main broke the PR; the fixes are in commit ab36686. I verified correct functioning using visualize_loss_mask.

I also updated the default Tensor Parallelism for the script in commit 26022f1.

Running it on a node with 2×H100 GPUs, I got:

Training Epoch 0:  10%|█████████████▍     | 12578/120675 [33:09<6:16:52,  4.78it/s, loss=2.74, acc=0.42]

Letting it run for a number of steps, I got the following MLflow charts: [MLflow chart screenshots]

@ValeGian (Contributor, Author) commented:

@ZhengHSI any update on this?

@ValeGian (Contributor, Author) commented Oct 2, 2025

@ZhengHSI it seems the latest merge from main broke the tests.

@ValeGian (Contributor, Author) commented Oct 7, 2025

@ZhengHSI is there any action needed on my side to get this PR closed out?
