From 8b0e5545f3202eb25d0bd4e8102b32b937b1f4e0 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Thu, 24 Oct 2024 01:07:48 +0800 Subject: [PATCH] [torch.compile] Adding torch compile annotations to some models (#9614) Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- vllm/model_executor/models/baichuan.py | 2 ++ vllm/model_executor/models/bloom.py | 2 ++ vllm/model_executor/models/commandr.py | 2 ++ vllm/model_executor/models/exaone.py | 2 ++ vllm/model_executor/models/gemma.py | 2 ++ vllm/model_executor/models/gpt2.py | 2 ++ 6 files changed, 12 insertions(+) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 767230aeacc35..f2cfdf8ffd30a 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -250,6 +251,7 @@ def forward( return hidden_states, residual +@support_torch_compile class BaiChuanModel(nn.Module): def __init__(self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index b2c9e221690b3..77ab7de6165fb 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -218,6 +219,7 @@ def forward( return output +@support_torch_compile class BloomModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 578cd2f04861b..348e6d20f3297 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -28,6 +28,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -250,6 +251,7 @@ def forward( return hidden_states, residual +@support_torch_compile class CohereModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index dfb8fe55d2fb8..4126ceb7117d4 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -29,6 +29,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -311,6 +312,7 @@ def forward( return hidden_states, residual +@support_torch_compile class ExaoneModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 91e556db70a0b..436bd45d53f35 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,6 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -239,6 +240,7 @@ def forward( return hidden_states, residual +@support_torch_compile class GemmaModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 975502340e5f9..3330d84021368 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_world_size) @@ -182,6 +183,7 @@ def forward( return hidden_states +@support_torch_compile class GPT2Model(nn.Module): def __init__(