Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support telechat2 #10311

Merged
merged 27 commits into from
Nov 27, 2024
Merged

[Model] Support telechat2 #10311

merged 27 commits into from
Nov 27, 2024

Conversation

shunxing12345
Copy link
Contributor

@shunxing12345 shunxing12345 commented Nov 14, 2024

Related #5776

FIX #6503

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@shunxing12345
Copy link
Contributor Author

Title: Add Support for TeleChat2 Model

Description:

Background: TeleChat2 is an open-source large language model developed by China Telecom's Artificial Intelligence Research Institute. It features various parameter scales and functionalities, including Function Call capabilities.

Modifications:

Model Integration: Added the implementation code for the TeleChat2 model in the vllm/model_executor/models directory.
Model Registration: Registered the TeleChat2 model in the model_registry.py file to enable recognition and utilization within vLLM.
Compatibility Adjustments: Modified the model's forward() method to ensure compatibility with vLLM's architecture.
Testing:

Functional Testing: Conducted inference tests in a local environment to verify the model's proper operation within vLLM.
Performance Evaluation: Assessed inference speed and resource utilization, confirming that performance meets expectations.
Compatibility:

These modifications do not affect vLLM's support for other models.
No new dependencies have been introduced.

xiangw2 added 2 commits November 14, 2024 11:15
Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that the model implementation is basically equivalent to Llama except module naming. (Please correct me if I'm wrong)

If so, I think we can simplify the model implementation by mapping weights names.

vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
self.act_fn = SiluAndMul()

def forward(self, x):
gate_output, _ = self.gate_proj(x)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge linear using MergedColumnParallelLinear, see: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L73

vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
@DarkLight1337 DarkLight1337 changed the title [Mode]support telechat2 [Model] Support telechat2 Nov 14, 2024
@shunxing12345
Copy link
Contributor Author

Here’s the revised response:

Thank you for your feedback and suggestions!

I’ve noted your points and will make the necessary changes as soon as possible. Additionally, regarding the model implementation, our architecture has some differences from Llama in terms of bias configurations:

  • Q and K components include biases, but V does not.
  • MLP layers also contain biases.

Because of these differences, Llama cannot directly load our model weights, as it supports only uniform bias configurations (all or none).

I’ve also addressed the points you raised that required modifications. Please feel free to share any further insights or suggestions!

xiangw2 added 2 commits November 20, 2024 15:27
vllm/model_executor/models/registry.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
@jeejeelee
Copy link
Collaborator

jeejeelee commented Nov 20, 2024

Q and K components include biases, but V does not.

Here’s the revised response:

Thank you for your feedback and suggestions!

I’ve noted your points and will make the necessary changes as soon as possible. Additionally, regarding the model implementation, our architecture has some differences from Llama in terms of bias configurations:

  • Q and K components include biases, but V does not.
  • MLP layers also contain biases.

Because of these differences, Llama cannot directly load our model weights, as it supports only uniform bias configurations (all or none).

I’ve also addressed the points you raised that required modifications. Please feel free to share any further insights or suggestions!

If these are the only two differences, it might be possible to integrate this model following PHI-3's approach, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi3.py

@Isotr0py
Copy link
Collaborator

  • Q and K components include biases, but V does not.
  • MLP layers also contain biases.

If the only difference is about bias, we can simply set bias=True on Llama's Attention and MLP, then handle this in weigts loading by creating full zeros bias for v_proj and MLP that might not have bias.

@shunxing12345
Copy link
Contributor Author

shunxing12345 commented Nov 21, 2024

Thank you for your detailed explanation! In TeleChat2, the bias is set to False for gate_up_proj in the MLP, while it is set to True for down_proj. Additionally, in the Attention module, I previously misspoke. The bias is actually False for qkv_proj, but it is True for dense.

If we directly rely on the bias setting in Llama’s MLP and Attention, it might introduce some issues.

@shunxing12345
Copy link
Contributor Author

Q and K components include biases, but V does not.

Here’s the revised response:
Thank you for your feedback and suggestions!
I’ve noted your points and will make the necessary changes as soon as possible. Additionally, regarding the model implementation, our architecture has some differences from Llama in terms of bias configurations:

  • Q and K components include biases, but V does not.
  • MLP layers also contain biases.

Because of these differences, Llama cannot directly load our model weights, as it supports only uniform bias configurations (all or none).
I’ve also addressed the points you raised that required modifications. Please feel free to share any further insights or suggestions!

If these are the only two differences, it might be possible to integrate this model following PHI-3's approach, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi3.py

Thank you for your feedback and suggestions!

As per your advice, I have inherited the Llama class and completed the implementation of the TeleChat2 code.

@Isotr0py
Copy link
Collaborator

BTW, please address the lint errors by running bash ./format.sh and add this model to docs/source/models/supported_models.rst. :)

@mergify mergify bot added the documentation Improvements or additions to documentation label Nov 22, 2024
@shunxing12345
Copy link
Contributor Author

BTW, please address the lint errors by running bash ./format.sh and add this model to docs/source/models/supported_models.rst. :)
image

Hi,

I have already run bash format.sh, and the result is shown in the attached screenshot—everything passed the format check. Additionally, I have updated the supported_models.rst file as well.

Please let me know if there’s anything else needed.

Thank you!

@shunxing12345
Copy link
Contributor Author

BTW, please address the lint errors by running bash ./format.sh and add this model to docs/source/models/supported_models.rst. :)

Thank you for your help!

@Isotr0py
Copy link
Collaborator

Isotr0py commented Nov 26, 2024

Hmmm, seems that the model is not compatible with torch.compile, because we're overriding initialized module in inherent:

INFO 11-26 09:26:29 config.py:1869] Downcasting torch.float32 to torch.float16.
INFO 11-26 09:26:36 config.py:373] This model supports multiple tasks: {'generate', 'embedding'}. Defaulting to 'generate'.
WARNING 11-26 09:26:36 arg_utils.py:1105] [DEPRECATED] Block manager v1 has been removed, and setting --use-v2-block-manager to True or False has no effect on vLLM behavior. Please remove --use-v2-block-manager in your engine argument. If your use case is not supported by SelfAttnBlockSpaceManager (i.e. block manager v2), please file an issue with detailed information.
INFO 11-26 09:26:36 llm_engine.py:248] Initializing an LLM engine (v0.1.dev3049+g82c2515.d20241019) with config: model='TeleAI/TeleChat2-3B', speculative_config=None, tokenizer='TeleAI/TeleChat2-3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=TeleAI/TeleChat2-3B, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None, pooler_config=None,compilation_config=CompilationConfig(level=0, backend='', custom_ops=[], splitting_ops=['vllm.unified_attention', 'vllm.unified_v1_flash_attention'], use_inductor=True, inductor_specialize_for_cudagraph_no_more_than=None, inductor_compile_sizes={}, inductor_compile_config={}, inductor_passes={}, use_cudagraph=False, cudagraph_num_of_warmups=0, cudagraph_capture_sizes=None, cudagraph_copy_inputs=False, pass_config=PassConfig(dump_graph_stages=[], dump_graph_dir=PosixPath('.'), enable_fusion=True, enable_reshape=True), compile_sizes=<function PrivateAttr at 0x7f3047efbac0>, capture_sizes=<function PrivateAttr at 0x7f3047efbac0>, enabled_custom_ops=Counter(), disabled_custom_ops=Counter(), static_forward_context={})
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
WARNING 11-26 09:26:38 tokenizer.py:174] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
INFO 11-26 09:26:43 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 11-26 09:26:43 selector.py:129] Using XFormers backend.
INFO 11-26 09:26:54 model_runner.py:1100] Starting to load model TeleAI/TeleChat2-3B...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/kaggle/working/vllm/examples/offline_inference_cli.py", line 80, in <module>
[rank0]:     main(args)
[rank0]:   File "/kaggle/working/vllm/examples/offline_inference_cli.py", line 37, in main
[rank0]:     llm = LLM(**asdict(engine_args))
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/utils.py", line 1052, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 225, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 574, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 335, in __init__
[rank0]:     self.model_executor = executor_class(vllm_config=vllm_config, )
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 36, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 35, in _init_executor
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/worker/worker.py", line 153, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1102, in load_model
[rank0]:     self.model = get_model(vllm_config=self.vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
[rank0]:     return loader.load_model(vllm_config=vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 339, in load_model
[rank0]:     model = _initialize_model(vllm_config=vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 106, in _initialize_model
[rank0]:     return model_class(vllm_config=vllm_config, prefix=prefix)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 209, in __init__
[rank0]:     self.model = TeleChat2Model(vllm_config=vllm_config,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 139, in __init__
[rank0]:     self.start_layer, self.end_layer, self.layers = make_layers(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/utils.py", line 510, in make_layers
[rank0]:     [PPMissingLayer() for _ in range(start_layer)] + [
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/utils.py", line 511, in <listcomp>
[rank0]:     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 141, in <lambda>
[rank0]:     lambda prefix: TeleChat2DecoderLayer(config=config,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 110, in __init__
[rank0]:     self.self_attn = TeleChat2Attention(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 76, in __init__
[rank0]:     super().__init__(config, hidden_size, num_heads, num_kv_heads,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 171, in __init__
[rank0]:     self.attn = Attention(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/attention/layer.py", line 110, in __init__
[rank0]:     raise ValueError(f"Duplicate layer name: {prefix}")
[rank0]: ValueError: Duplicate layer name: .attn

Will take a look tonight to see how we can do some refactor on curret implementation. :(

@shunxing12345
Copy link
Contributor Author

shunxing12345 commented Nov 26, 2024

Hmmm, seems that the model is not compatible with torch.compile, because we're overriding initialized module in inherent:

INFO 11-26 09:26:29 config.py:1869] Downcasting torch.float32 to torch.float16.
INFO 11-26 09:26:36 config.py:373] This model supports multiple tasks: {'generate', 'embedding'}. Defaulting to 'generate'.
WARNING 11-26 09:26:36 arg_utils.py:1105] [DEPRECATED] Block manager v1 has been removed, and setting --use-v2-block-manager to True or False has no effect on vLLM behavior. Please remove --use-v2-block-manager in your engine argument. If your use case is not supported by SelfAttnBlockSpaceManager (i.e. block manager v2), please file an issue with detailed information.
INFO 11-26 09:26:36 llm_engine.py:248] Initializing an LLM engine (v0.1.dev3049+g82c2515.d20241019) with config: model='TeleAI/TeleChat2-3B', speculative_config=None, tokenizer='TeleAI/TeleChat2-3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=TeleAI/TeleChat2-3B, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None, pooler_config=None,compilation_config=CompilationConfig(level=0, backend='', custom_ops=[], splitting_ops=['vllm.unified_attention', 'vllm.unified_v1_flash_attention'], use_inductor=True, inductor_specialize_for_cudagraph_no_more_than=None, inductor_compile_sizes={}, inductor_compile_config={}, inductor_passes={}, use_cudagraph=False, cudagraph_num_of_warmups=0, cudagraph_capture_sizes=None, cudagraph_copy_inputs=False, pass_config=PassConfig(dump_graph_stages=[], dump_graph_dir=PosixPath('.'), enable_fusion=True, enable_reshape=True), compile_sizes=<function PrivateAttr at 0x7f3047efbac0>, capture_sizes=<function PrivateAttr at 0x7f3047efbac0>, enabled_custom_ops=Counter(), disabled_custom_ops=Counter(), static_forward_context={})
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
WARNING 11-26 09:26:38 tokenizer.py:174] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
Downloading Model to directory: /root/.cache/modelscope/hub/TeleAI/TeleChat2-3B
INFO 11-26 09:26:43 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 11-26 09:26:43 selector.py:129] Using XFormers backend.
INFO 11-26 09:26:54 model_runner.py:1100] Starting to load model TeleAI/TeleChat2-3B...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/kaggle/working/vllm/examples/offline_inference_cli.py", line 80, in <module>
[rank0]:     main(args)
[rank0]:   File "/kaggle/working/vllm/examples/offline_inference_cli.py", line 37, in main
[rank0]:     llm = LLM(**asdict(engine_args))
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/utils.py", line 1052, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 225, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 574, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 335, in __init__
[rank0]:     self.model_executor = executor_class(vllm_config=vllm_config, )
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 36, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 35, in _init_executor
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/worker/worker.py", line 153, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1102, in load_model
[rank0]:     self.model = get_model(vllm_config=self.vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
[rank0]:     return loader.load_model(vllm_config=vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 339, in load_model
[rank0]:     model = _initialize_model(vllm_config=vllm_config)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 106, in _initialize_model
[rank0]:     return model_class(vllm_config=vllm_config, prefix=prefix)
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 209, in __init__
[rank0]:     self.model = TeleChat2Model(vllm_config=vllm_config,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 139, in __init__
[rank0]:     self.start_layer, self.end_layer, self.layers = make_layers(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/utils.py", line 510, in make_layers
[rank0]:     [PPMissingLayer() for _ in range(start_layer)] + [
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/utils.py", line 511, in <listcomp>
[rank0]:     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 141, in <lambda>
[rank0]:     lambda prefix: TeleChat2DecoderLayer(config=config,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 110, in __init__
[rank0]:     self.self_attn = TeleChat2Attention(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/telechat2.py", line 76, in __init__
[rank0]:     super().__init__(config, hidden_size, num_heads, num_kv_heads,
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 171, in __init__
[rank0]:     self.attn = Attention(
[rank0]:   File "/opt/conda/envs/vllm/lib/python3.10/site-packages/vllm/attention/layer.py", line 110, in __init__
[rank0]:     raise ValueError(f"Duplicate layer name: {prefix}")
[rank0]: ValueError: Duplicate layer name: .attn

Will take a look tonight to see how we can do some refactor on curret implementation. :(

yeah... It seems like there might have been some changes to the layer names in LLaMA, as the same code works perfectly fine with vLLM 0.6.4.post1. If that’s the case, I’m a bit concerned that similar compatibility issues might arise again in the future if LLaMA undergoes further updates after we make the necessary adjustments. Do you think there’s a way to address this in a more stable manner moving forward? I’d love to hear your thoughts.

@Isotr0py
Copy link
Collaborator

Isotr0py commented Nov 26, 2024

@shunxing12345 I think a possible way is pruning the bias from a LlamaModel.

I have drafted a refactored implementation for telechat2 on my fork: https://github.com/vllm-project/vllm/blob/4ac891f8ccfc2ef1260cc9023fbcb62d1f876c26/vllm/model_executor/models/telechat2.py.

I have tested it on Telechat2-3B and the 2x VRAM issue should be resolved as well with this implementation. But I haven't handled qkv_bias etc config for bias prune, feel free to copy to this PR and make any modifications that you think necessary :)

@shunxing12345
Copy link
Contributor Author

@shunxing12345 I think a possible way is pruning the bias from a LlamaModel.

I have drafted a refactored implementation for telechat2 on my fork: https://github.com/vllm-project/vllm/blob/4ac891f8ccfc2ef1260cc9023fbcb62d1f876c26/vllm/model_executor/models/telechat2.py.

I have tested it on Telechat2-3B and the 2x VRAM issue should be resolved as well with this implementation. But I haven't handled qkv_bias etc config for bias prune, feel free to copy to this PR and make any modifications that you think necessary :)

image

Thank you for your assistance! I have tried your code, but when using vLLM to serve the 35B and 115B models, the process gets stuck at this stage (as shown in the screenshot) for more than 10 minutes without any response. Could you please provide guidance on how to resolve this issue? Your help would be greatly appreciated!

@Isotr0py
Copy link
Collaborator

I have tried your code, but when using vLLM to serve the 35B and 115B models, the process gets stuck at this stage (as shown in the screenshot) for more than 10 minutes without any response.

Perhaps you can try offline inference instead. Just run:

VLLM_USE_MODELSCOPE=True python examples/offline_inference_cli.py --model TeleAI/TeleChat2-3B --max-model-len 4096 --trust-remote-code

@shunxing12345
Copy link
Contributor Author

I have tried your code, but when using vLLM to serve the 35B and 115B models, the process gets stuck at this stage (as shown in the screenshot) for more than 10 minutes without any response.

Perhaps you can try offline inference instead. Just run:

VLLM_USE_MODELSCOPE=True python examples/offline_inference_cli.py --model TeleAI/TeleChat2-3B --max-model-len 4096 --trust-remote-code

Thank you so much for your guidance and help! Apologies for my earlier mistake—now I have successfully gotten TeleChat2 to run across all sizes. Your assistance has been incredibly helpful. Thank you again!😊

Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for adding this model!

vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
Signed-off-by: Isotr0py <[email protected]>
@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 27, 2024
tests/models/registry.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
vllm/model_executor/models/telechat2.py Outdated Show resolved Hide resolved
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
@Isotr0py Isotr0py enabled auto-merge (squash) November 27, 2024 09:52
@Isotr0py Isotr0py merged commit 1209261 into vllm-project:main Nov 27, 2024
49 checks passed
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: xiangw2 <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: xiangw2 <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New Model]: Support for Telechat
4 participants