diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 214448bf4320e..ed6360f9d6148 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -171,7 +171,8 @@ def iter_params(self, model_name: str):
     "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
     "bigcode/starcoder2-3b": PPTestSettings.fast(),
     "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
-    # FIXME: Cannot load tokenizer in latest transformers version
+    # FIXME: Cannot load tokenizer in latest transformers version.
+    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
     # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-only]
     # TODO: Implement PP
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 3bcdb0d87fd52..37c3fa919124e 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -24,6 +24,7 @@
 from transformers import OPTConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -279,6 +280,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class OPTModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 0913193f73a48..055407587c598 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -11,6 +11,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -184,7 +185,6 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -203,9 +203,10 @@
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states, None
+        return hidden_states
 
 
+@support_torch_compile
 class OrionModel(nn.Module):
 
     def __init__(
@@ -233,8 +234,9 @@ def __init__(
             prefix=f"{prefix}.layers")
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+            make_empty_intermediate_tensors_factory([
+                "hidden_states",
+            ], config.hidden_size))
 
     def forward(
         self,
@@ -246,24 +248,20 @@
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
             hidden_states = self.embed_tokens(input_ids)
-            residual = None
         else:
-            assert intermediate_tensors
+            assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i - self.start_layer],
                 attn_metadata,
-                residual,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
-                "residual": residual
             })
         hidden_states = self.norm(hidden_states)
         return hidden_states
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index b625d19f6447d..fc9ef15db26c0 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -27,6 +27,7 @@
 from transformers import PersimmonConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -209,6 +210,7 @@ def forward(
         return outputs
 
 
+@support_torch_compile
 class PersimmonModel(nn.Module):
 
     def __init__(self,
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index b9298ed031144..5a3dd3c02b85b 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -29,6 +29,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -263,6 +264,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class SolarModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 81dd7c4daa5e9..8f0644bca3e2e 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -25,6 +25,7 @@
 from transformers import Starcoder2Config
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class Starcoder2Model(nn.Module):
 
     def __init__(self,
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 3bded82033c08..036789642d3c4 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -27,6 +27,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -220,6 +221,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class XverseModel(nn.Module):
 
     def __init__(
@@ -266,6 +268,7 @@ def forward(
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(