From 675759a02cff446d9f152e2c801ca98079e51af4 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Thu, 24 Oct 2024 16:32:15 +0800
Subject: [PATCH] [torch.compile] expanding support and fix allgather compilation (#9637)

Signed-off-by: youkaichao
Co-authored-by: youkaichao
Signed-off-by: Tyler Michael Smith
---
 vllm/distributed/parallel_state.py        | 7 ++++++-
 vllm/model_executor/models/gpt_bigcode.py | 2 ++
 vllm/model_executor/models/gpt_j.py       | 2 ++
 vllm/model_executor/models/gpt_neox.py    | 2 ++
 vllm/model_executor/models/granite.py     | 2 ++
 vllm/model_executor/models/internlm2.py   | 2 ++
 6 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index ab47d62921d2c..ec39856b6f67c 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -392,8 +392,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
             # Convert negative dim to positive.
             dim += input_.dim()
         input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * world_size, ) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty((world_size, ) + input_size,
+        output_tensor = torch.empty(output_size,
                                     dtype=input_.dtype,
                                     device=input_.device)
         # All-gather.
@@ -401,6 +405,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                                  input_,
                                                  group=self.device_group)
         # Reshape
+        output_tensor = output_tensor.reshape((world_size, ) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
                                               (world_size *
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 6c4a04667c5da..24c79a8855475 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -25,6 +25,7 @@
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTBigCodeModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index d40bf8c88ee19..0451d16b6c738 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -23,6 +23,7 @@
 from transformers import GPTJConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -174,6 +175,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTJModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 23a1ca06cc69e..1bccef7a5f173 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -23,6 +23,7 @@
 from transformers import GPTNeoXConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTNeoXModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index dcf4f5b27704a..5a397ed8ff6a0 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -28,6 +28,7 @@
 from transformers import GraniteConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -254,6 +255,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GraniteModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index f6cde44e9d83d..9a77e48626ca5 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -7,6 +7,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -230,6 +231,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class InternLM2Model(nn.Module):
 
     def __init__(
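
Note for readers of this patch: the sketch below is not part of the change. It is a minimal single-process illustration (the value of world_size, the example shapes, and the finish helper are made up here) of why the concat-style gather buffer plus the added reshape reproduces the old stack-style layout, so the existing movedim/reshape post-processing can stay as-is. The @support_torch_compile additions to the model files only apply vLLM's existing decorator to more model classes.

# Single-process sketch, not vLLM code: rank_inputs stands in for the
# per-rank tensors that torch.distributed.all_gather_into_tensor collects.
import torch

world_size = 4          # hypothetical tensor-parallel size
dim = 1                 # gather dimension used in this example
input_size = (2, 3)     # per-rank tensor shape, chosen arbitrarily

rank_inputs = [torch.randn(input_size) for _ in range(world_size)]

# Old layout: stack-style, with a leading world_size dimension.
stacked = torch.stack(rank_inputs, dim=0)

# New layout: concat-style, matching the flat output buffer of shape
# (input_size[0] * world_size, ) + input_size[1:] allocated by the patch.
concatenated = torch.cat(rank_inputs, dim=0)

# The reshape added after the collective recovers the stack-style view...
recovered = concatenated.reshape((world_size, ) + input_size)

# ...so the original post-processing (movedim + reshape) is unchanged.
def finish(output_tensor: torch.Tensor) -> torch.Tensor:
    output_tensor = output_tensor.movedim(0, dim)
    return output_tensor.reshape(input_size[:dim] +
                                 (world_size * input_size[dim], ) +
                                 input_size[dim + 1:])

assert torch.equal(finish(recovered), finish(stacked))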