Adding "torch compile" annotations to moe models (vllm-project#9758)
CRZbulabula authored Oct 28, 2024
1 parent 5f8d807 · commit aa0addb
Showing 4 changed files with 8 additions and 0 deletions.
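All four files apply the same two-line pattern, visible in the diffs below: import support_torch_compile from vllm.compilation.decorators and decorate the model's top-level nn.Module so that vLLM can route its forward pass through torch.compile when compilation is enabled. A minimal sketch of the pattern follows; ToyMoEModel and its layers are hypothetical stand-ins, not code from this commit, and the decorator's exact requirements (such as which forward arguments it treats as dynamically shaped) vary across vLLM versions.

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class ToyMoEModel(nn.Module):
    """Hypothetical model, used only to illustrate the decorator pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(1000, 64)  # placeholder layer

    # input_ids/positions mirror the argument names vLLM model classes use;
    # whether the decorator requires them here is an assumption (see above).
    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)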
vllm/model_executor/models/arctic.py (2 additions, 0 deletions)

@@ -5,6 +5,7 @@
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -360,6 +361,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class ArcticModel(nn.Module):

     def __init__(
vllm/model_executor/models/mixtral.py (2 additions, 0 deletions)

@@ -28,6 +28,7 @@
 from transformers import MixtralConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -245,6 +246,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class MixtralModel(nn.Module):

     def __init__(
vllm/model_executor/models/olmoe.py (2 additions, 0 deletions)

@@ -17,6 +17,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -239,6 +240,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class OlmoeModel(nn.Module):

     def __init__(
vllm/model_executor/models/phimoe.py (2 additions, 0 deletions)

@@ -28,6 +28,7 @@
 from transformers.configuration_utils import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -429,6 +430,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class PhiMoEModel(nn.Module):

     def __init__(
