Added Mistral model support #208
Conversation
Summary of Changes
Hello @ValeGian, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates Mistral model support into the framework, enabling users to train and utilize Mistral models with distributed capabilities. It includes the necessary model architecture definitions, configuration files, and training scripts, alongside improvements to chat template processing for broader compatibility.
Highlights
- Mistral Model Support: Added comprehensive support for Mistral models, including their distributed implementation and configuration.
- Distributed Model Implementation: Introduced distributed (Tensor Parallelism) implementations for Mistral model components, such as MLP, Attention, and the overall CausalLM (a generic sketch of the column/row-parallel split follows this list).
- Chat Template Enhancements: Improved flexibility in chat template handling by allowing optional end-of-turn tokens and introducing specific end-of-assistant/user tokens.
- New Configuration and Training Scripts: Included a new configuration file and a training script specifically tailored for the mistral-small-24B-eagle3 model.
- Unit Testing for Tensor Parallelism: Added a dedicated unit test to validate the correctness of the tensor parallelism implementation for Mistral models.
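To illustrate the tensor-parallel split mentioned in the Distributed Model Implementation highlight, here is a minimal single-process sketch of the standard Megatron-style column/row-parallel decomposition applied to a Mistral-style MLP (down(silu(gate(x)) * up(x))). This is a generic illustration with toy shapes, not the PR's actual implementation:

```python
import torch
import torch.nn.functional as F

hidden, intermediate, tp = 8, 16, 2
x = torch.randn(1, hidden)

# Full (unsharded) Mistral-style MLP weights.
gate = torch.randn(intermediate, hidden)
up = torch.randn(intermediate, hidden)
down = torch.randn(hidden, intermediate)

reference = (F.silu(x @ gate.T) * (x @ up.T)) @ down.T

# Column-parallel: gate/up are split along the intermediate dimension, so
# each rank computes its slice of the activation with no communication.
# Row-parallel: down is split along its input dimension; the partial
# outputs are summed (an all-reduce in a real distributed run).
partial_sum = torch.zeros(1, hidden)
for rank in range(tp):
    shard = slice(rank * intermediate // tp, (rank + 1) * intermediate // tp)
    act = F.silu(x @ gate[shard].T) * (x @ up[shard].T)
    partial_sum += act @ down[:, shard].T  # one rank's partial output

assert torch.allclose(reference, partial_sum, atol=1e-5)
```

In a real distributed run each rank holds only its own shard, and the final summation is an all-reduce over the tensor-parallel group.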
Code Review
This pull request adds support for Mistral models, including the distributed model implementation, a new chat template, and training scripts. The changes are well-structured and include a new test for tensor parallelism, which is great. I've found a couple of minor issues with incorrect type hints in the new mistral.py model file, which I've commented on. Correcting these will improve code clarity and maintainability. Overall, this is a solid contribution.
) -> tuple[
    torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
The return type hint for this function is incorrect. The function returns outputs, which is either (hidden_states,) or (hidden_states, self_attn_weights). This corresponds to Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]. The current annotation is misleading, as it suggests a different structure, similar to what might be returned if past_key_value were part of the output, which it is not.
) -> Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]:
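To make the reviewer's point concrete, here is a minimal sketch of the return pattern being described. It follows the upstream Hugging Face Mistral decoder layer rather than the PR's exact code, and the function name is illustrative:

```python
from typing import Optional, Union

import torch

def decoder_layer_forward_tail(
    hidden_states: torch.FloatTensor,
    self_attn_weights: Optional[torch.Tensor],
    output_attentions: bool,
) -> Union[tuple[torch.FloatTensor], tuple[torch.FloatTensor, Optional[torch.Tensor]]]:
    # The tuple has one or two elements; past_key_value is never included.
    outputs = (hidden_states,)
    if output_attentions:
        outputs += (self_attn_weights,)
    return outputs
```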
Could you fix the code format, @ValeGian?
Done with 06cdfeb.
May I ask if you ran the training on the device mentioned above?
The model itself is around 47 GB on disk. I ran the training on a node of 8xH200. I just tried on a smaller node and was able to run a test training there as well. Do you want me to update the script?
@ZhengHSI consider that the training for which I reported the curves was optimized to run on the 8xH200 node; the complete set of parameters can be found on MLflow. I didn't upload the updated configuration, as I saw that in the examples folder you keep pretty much the same configuration for every training script, even for larger models such as meta-llama/Llama-4-Scout-17B-16E.
Thanks for your answer. It would be better to update the script: your current script does not set the TP size, which means tensor parallelism is not enabled and leads to OOM. Please modify the script accordingly.
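For context on the OOM, here is a rough, assumption-laden estimate. It assumes EAGLE3-style training keeps the ~24B-parameter target model frozen in bf16, so the frozen weights dominate per-GPU memory; activations and the trainable draft head add more on top:

```python
target_params = 24e9
weights_gb = target_params * 2 / 1e9  # bf16 -> ~48 GB, matching ~47 GB on disk

for tp in (1, 2):
    per_gpu = weights_gb / tp
    print(f"tp={tp}: ~{per_gpu:.0f} GB of frozen target weights per GPU")

# tp=1 leaves little headroom on a single 80 GB H100 once activations and
# the draft model are added; tp=2 halves the weight footprint to ~24 GB.
```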
@ZhengHSI I confirmed that recent merges from main broke the PR; you can find the fixes in commit ab36686. I verified the correct functioning using visualize_loss_mask. I also updated the default tensor parallelism for the script in commit 26022f1. Running it on a node with 2 H100s, I got the results below. Leaving it to run for some steps, I got the following MLflow charts.
@ZhengHSI any update on this?
@ZhengHSI it seems like the latest merge from main broke the tests.
@ZhengHSI is there any action needed on my side to allow closing this PR?
Motivation
This PR aims to add support for training Mistral models.
Modifications
Accuracy Test
mistralai/Mistral-Small-24B-Instruct-2501 training


Checklist