[Model] Support telechat2 #10311
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Title: Add Support for TeleChat2 Model

Background: TeleChat2 is an open-source large language model developed by China Telecom's Artificial Intelligence Research Institute. It is available in various parameter scales and includes capabilities such as Function Call.

Modifications:
- Model integration: added the implementation code for the TeleChat2 model in the vllm/model_executor/models directory.
- Functional testing: conducted inference tests in a local environment to verify the model's proper operation within vLLM.

These modifications do not affect vLLM's support for other models.
Seems that the model implementation is basically equivalent to Llama except for module naming. (Please correct me if I'm wrong.)
If so, I think we can simplify the model implementation by mapping weight names.
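For illustration, a minimal sketch of what such a remapping could look like; the substring pairs below are assumptions about the TeleChat2 checkpoint layout, not the actual mapping:

```python
# Hypothetical sketch: rename TeleChat2 checkpoint keys to Llama-style module
# names before handing the weights to the Llama loader.
# The substring pairs here are assumptions, not the real mapping.
NAME_MAPPING = {
    "transformer.": "model.",           # top-level prefix (assumed)
    ".self_attention.": ".self_attn.",  # attention block name (assumed)
    ".word_embeddings.": ".embed_tokens.",
}


def remap_key(name: str) -> str:
    """Rewrite one checkpoint key using the substring mapping above."""
    for old, new in NAME_MAPPING.items():
        name = name.replace(old, new)
    return name


def remap_weights(weights):
    """Yield (name, tensor) pairs with Llama-style names."""
    for name, tensor in weights:
        yield remap_key(name), tensor
```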
```python
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_output, _ = self.gate_proj(x)
```
Can we merge the linear layers using `MergedColumnParallelLinear`? See: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L73
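For reference, a sketch of the merged-projection pattern from the linked llama.py, assuming the usual signatures of MergedColumnParallelLinear and RowParallelLinear; the class and parameter names here are illustrative:

```python
from torch import nn

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)


class TeleChat2MLP(nn.Module):
    """Sketch of an MLP with the gate and up projections merged into one GEMM."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 quant_config=None):
        super().__init__()
        # A single merged column-parallel GEMM computes gate_proj and up_proj.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False,
            quant_config=quant_config)
        # SiluAndMul splits the merged output in half: silu(gate) * up.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
```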
Thank you for your feedback and suggestions! I've noted your points and will make the necessary changes as soon as possible. Additionally, regarding the model implementation, our architecture differs from Llama in its bias configuration: some projections use a bias while others do not.
Because of these differences, Llama cannot directly load our model weights, as it supports only uniform bias configurations (all or none). I've also addressed the points you raised that required modifications. Please feel free to share any further insights or suggestions!
If these are the only two differences, it might be possible to integrate this model following Phi-3's approach; please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi3.py
If the only difference is the bias, we can simply set the bias flags accordingly.
Thank you for your detailed explanation! In TeleChat2, the bias is set to False for some projections but not others. If we directly rely on the single bias setting in Llama's MLP and Attention, it might introduce some issues.
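To illustrate the point, a hypothetical sketch of attention projections where each one takes its own bias flag instead of a single shared Llama-style setting; the flag names are placeholders, not TeleChat2's actual configuration:

```python
from torch import nn

from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)


class TeleChat2AttentionProjections(nn.Module):
    """Sketch: independent bias flags for the QKV and output projections."""

    def __init__(self, hidden_size: int, head_dim: int, num_heads: int,
                 num_kv_heads: int, qkv_bias: bool, o_bias: bool,
                 quant_config=None):
        super().__init__()
        # The two projections can disagree on bias, which a single
        # Llama-style `bias` flag cannot express.
        self.qkv_proj = QKVParallelLinear(
            hidden_size, head_dim, num_heads, num_kv_heads,
            bias=qkv_bias, quant_config=quant_config)
        self.o_proj = RowParallelLinear(
            num_heads * head_dim, hidden_size,
            bias=o_bias, quant_config=quant_config)
```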
Thank you for your feedback and suggestions! As per your advice, I have inherited the Llama class and completed the TeleChat2 implementation.
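For context, a minimal sketch of what the inheritance approach can look like: reuse LlamaForCausalLM and only rename checkpoint keys at load time. The class body and the substring map are illustrative assumptions, not the code that was merged.

```python
from vllm.model_executor.models.llama import LlamaForCausalLM

# Hypothetical substring renames; the real pairs must match the checkpoint keys.
_NAME_MAP = {
    "transformer.": "model.",
    ".self_attention.": ".self_attn.",
}


class TeleChat2ForCausalLM(LlamaForCausalLM):
    """Sketch: reuse the Llama implementation, remap weight names when loading."""

    def load_weights(self, weights):
        def renamed():
            for name, tensor in weights:
                for old, new in _NAME_MAP.items():
                    name = name.replace(old, new)
                yield name, tensor

        # Defer to the stock Llama weight loader after renaming.
        return super().load_weights(renamed())
```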
BTW, please address the lint errors by running the repo's formatting script.
Thank you for your help!
Hmmm, it seems that the model is not compatible with the current main branch.
Will take a look tonight to see how we can refactor the current implementation. :(
Yeah... It seems like there might have been some changes to the layer names in LLaMA, as the same code works perfectly fine with vLLM 0.6.4.post1. If that's the case, I'm a bit concerned that similar compatibility issues might arise again in the future if LLaMA undergoes further updates after we make the necessary adjustments. Do you think there's a way to address this in a more stable manner moving forward? I'd love to hear your thoughts.
merge from main
@shunxing12345 I think a possible way is to prune the bias from a LlamaModel. I have drafted a refactored implementation along these lines and tested it on my side.
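A hypothetical sketch of the bias-pruning idea: build the stock Llama layers with bias enabled, then drop the bias parameters that the TeleChat2 checkpoint does not provide. The attribute names and the choice of which biases to remove are assumptions for illustration:

```python
from vllm.model_executor.models.llama import LlamaModel


class TeleChat2Model(LlamaModel):
    """Sketch: start from LlamaModel and prune biases the checkpoint lacks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for layer in self.layers:
            if not hasattr(layer, "self_attn"):
                continue  # e.g. layers owned by another pipeline stage
            # Assumption for illustration: no bias on the QKV and gate/up
            # projections, so clear the parameters Llama registered for them.
            layer.self_attn.qkv_proj.bias = None
            layer.mlp.gate_up_proj.bias = None
```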
Thank you for your assistance! I have tried your code, but when using vLLM to serve the 35B and 115B models, the process gets stuck at this stage (as shown in the screenshot) for more than 10 minutes without any response. Could you please provide guidance on how to resolve this issue? Your help would be greatly appreciated!
Perhaps you can try offline inference instead. Just run:

```bash
VLLM_USE_MODELSCOPE=True python examples/offline_inference_cli.py --model TeleAI/TeleChat2-3B --max-model-len 4096 --trust-remote-code
```
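The same check can also be run from Python via the offline LLM API; the prompt and sampling parameters below are just for illustration:

```python
import os

# Pull the checkpoint from ModelScope, matching the command above.
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM, SamplingParams

llm = LLM(model="TeleAI/TeleChat2-3B",
          max_model_len=4096,
          trust_remote_code=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)
```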
Thank you so much for your guidance and help! Apologies for my earlier mistake. I have now successfully gotten TeleChat2 to run across all sizes. Your assistance has been incredibly helpful. Thank you again! 😊
LGTM! Thanks for adding this model!
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]> Co-authored-by: xiangw2 <[email protected]> Co-authored-by: Isotr0py <[email protected]> Signed-off-by: Andrew Feldman <[email protected]>
Signed-off-by: Isotr0py <[email protected]> Co-authored-by: xiangw2 <[email protected]> Co-authored-by: Isotr0py <[email protected]>
Related #5776
FIX #6503