[torch.compile] Adding torch compile annotations to some models (#9614)
CRZbulabula authored Oct 23, 2024 · 1 parent fd0e2cf · commit 9013e24
Showing 6 changed files with 12 additions and 0 deletions.
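Each file in this commit makes the same two-line change: import support_torch_compile from vllm.compilation.decorators and apply it as a class decorator to the model's core nn.Module (the inner *Model class, e.g. BaiChuanModel, rather than the top-level causal-LM wrapper). Below is a minimal sketch of the pattern with a hypothetical ToyModel; the decorator marks the module so vLLM can capture and compile its forward pass with torch.compile when compilation is enabled (at the time of this commit, gated by vLLM's compilation level, e.g. the VLLM_TORCH_COMPILE_LEVEL environment variable).

# Sketch of the annotation pattern repeated in the diffs below.
# ToyModel is a placeholder, not a class from this commit; the decorated
# forward is compiled only when vLLM's torch.compile integration is enabled.
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class ToyModel(nn.Module):

    def forward(self, x):
        return x + 1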
2 changes: 2 additions & 0 deletions vllm/model_executor/models/baichuan.py
@@ -26,6 +26,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class BaiChuanModel(nn.Module):

     def __init__(self,
2 changes: 2 additions & 0 deletions vllm/model_executor/models/bloom.py
@@ -24,6 +24,7 @@
 from transformers import BloomConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -218,6 +219,7 @@ def forward(
         return output


+@support_torch_compile
 class BloomModel(nn.Module):

     def __init__(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/commandr.py
@@ -28,6 +28,7 @@
 from transformers import CohereConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class CohereModel(nn.Module):

     def __init__(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/exaone.py
@@ -29,6 +29,7 @@
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -311,6 +312,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class ExaoneModel(nn.Module):

     def __init__(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gemma.py
@@ -22,6 +22,7 @@
 from transformers import GemmaConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -239,6 +240,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class GemmaModel(nn.Module):

     def __init__(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gpt2.py
@@ -24,6 +24,7 @@
 from transformers import GPT2Config

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
@@ -182,6 +183,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class GPT2Model(nn.Module):

     def __init__(
