Skip to content

Commit

Permalink
[torch.compile] Adding torch compile annotations to some models (vllm…
Browse files Browse the repository at this point in the history
…-project#9614)

Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
CRZbulabula authored and mfournioux committed Nov 20, 2024
1 parent 0cee7d4 commit 8b0e554
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 0 deletions.
2 changes: 2 additions & 0 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -250,6 +251,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class BaiChuanModel(nn.Module):

def __init__(self,
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from transformers import BloomConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -218,6 +219,7 @@ def forward(
return output


@support_torch_compile
class BloomModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
Expand Down Expand Up @@ -250,6 +251,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class CohereModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/exaone.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -311,6 +312,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class ExaoneModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import GemmaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
Expand Down Expand Up @@ -239,6 +240,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class GemmaModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from transformers import GPT2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
Expand Down Expand Up @@ -182,6 +183,7 @@ def forward(
return hidden_states


@support_torch_compile
class GPT2Model(nn.Module):

def __init__(
Expand Down

0 comments on commit 8b0e554

Please sign in to comment.