diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 5eddd3609e7f0..8454ceff98201 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip, print_warning_once
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import is_pp_missing_parameter, make_layers, PPMissingLayer
 
 
 class LlamaMLP(nn.Module):
@@ -263,6 +263,8 @@ def __init__(
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
             )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
@@ -271,7 +273,7 @@ def __init__(
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
-            self.norm = None
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -386,9 +388,9 @@ def __init__(
                                                     config.vocab_size,
                                                     logit_scale)
             self.sampler = Sampler()
         else:
-            self.lm_head = None
-            self.logits_processor = None
-            self.sampler = None
+            self.lm_head = PPMissingLayer()
+            self.logits_processor = PPMissingLayer()
+            self.sampler = PPMissingLayer()
 
     def forward(
         self,
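For context, the idea behind this change is to stop representing layers owned by other pipeline-parallel ranks as `None` and instead keep a parameter-free placeholder module in their place, so the module hierarchy stays uniform across ranks and code that walks submodules or the state dict never trips over a `NoneType`. Below is a minimal, hypothetical sketch of what such a placeholder could look like; the real `PPMissingLayer` is imported from `.utils` (i.e. `vllm/model_executor/models/utils.py`) and its exact definition may differ.

```python
import torch
import torch.nn as nn


class PPMissingLayerSketch(nn.Identity):
    """Hypothetical stand-in for a layer owned by another pipeline rank.

    Subclassing nn.Identity keeps the module tree intact (it still shows
    up in named_modules()) while contributing zero parameters, so weight
    loading and state-dict iteration can simply skip it.
    """

    def __init__(self, *args, **kwargs):
        # Ignore whatever constructor arguments the real layer would take.
        super().__init__()


if __name__ == "__main__":
    # A middle pipeline rank would hold placeholders for embed_tokens,
    # norm, lm_head, etc.; they add nothing to the parameter count.
    stub = PPMissingLayerSketch()
    assert sum(p.numel() for p in stub.parameters()) == 0
    # nn.Identity returns its input unchanged, so an accidental call is a
    # harmless no-op rather than "'NoneType' object is not callable".
    x = torch.randn(2, 4)
    assert torch.equal(stub(x), x)
```

Compared with storing `None`, a placeholder like this means methods such as `get_input_embeddings` and the weight-loading path need no per-rank special casing on every attribute access.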