Skip to content

Commit

Permalink
[torch.compile] Adding torch compile annotations to some models (vllm…
Browse files Browse the repository at this point in the history
…-project#9876)

Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
  • Loading branch information
2 people authored and Linkun Chen committed Nov 4, 2024
1 parent cac4710 commit 4b6f35d
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Text Generation
- ✅︎
* - :code:`Qwen2ForCausalLM`
- Qwen2
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
- :code:`Qwen/Qwen2-7B-Instruct`, :code:`Qwen/Qwen2-7B`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2MoeForCausalLM`
Expand Down
2 changes: 1 addition & 1 deletion tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def iter_params(self, model_name: str):
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
"bigcode/starcoder2-3b": PPTestSettings.fast(),
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from transformers import FalconConfig as HF_FalconConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -329,6 +330,7 @@ def forward(
return output


@support_torch_compile
class FalconModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from transformers import PhiConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -193,6 +194,7 @@ def forward(
return hidden_states


@support_torch_compile
class PhiModel(nn.Module):

def __init__(self,
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
Expand Down Expand Up @@ -549,6 +550,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class QWenModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from transformers import Qwen2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
Expand Down Expand Up @@ -237,6 +238,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class Qwen2Model(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -312,6 +313,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class Qwen2MoeModel(nn.Module):

def __init__(
Expand Down

0 comments on commit 4b6f35d

Please sign in to comment.