Use PPMissingLayer after rebasing to latest
wushidonguc committed Jul 16, 2024
1 parent 38d0834 commit 2de01e3
Showing 1 changed file with 7 additions and 5 deletions.
vllm/model_executor/models/llama.py: 7 additions & 5 deletions
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip, print_warning_once
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import is_pp_missing_parameter, make_layers, PPMissingLayer
 
 
 class LlamaMLP(nn.Module):
@@ -263,6 +263,8 @@ def __init__(
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
             )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
@@ -271,7 +273,7 @@
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
-            self.norm = None
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -386,9 +388,9 @@ def __init__(
                                                     config.vocab_size, logit_scale)
             self.sampler = Sampler()
         else:
-            self.lm_head = None
-            self.logits_processor = None
-            self.sampler = None
+            self.lm_head = PPMissingLayer()
+            self.logits_processor = PPMissingLayer()
+            self.sampler = PPMissingLayer()
 
     def forward(
         self,
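
Note: the change above replaces bare None assignments with PPMissingLayer() placeholders on pipeline-parallel ranks that do not own the embedding, final norm, or LM head. Below is a minimal sketch of the idea, assuming PPMissingLayer behaves as an inert identity module (the real class is imported from .utils and may differ in detail). Keeping a registered nn.Module instead of None means attribute access, submodule traversal, and isinstance checks stay uniform across ranks.

# Illustrative sketch only -- not the actual vLLM implementation.
import torch
import torch.nn as nn


class PPMissingLayer(nn.Identity):
    """Placeholder for a layer owned by another pipeline-parallel rank.

    Registering an inert module (rather than assigning None) keeps the
    module tree identical on every rank, so code that walks submodules
    or checks isinstance(layer, nn.Module) needs no special cases.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any constructor arguments so the placeholder
        # can stand in for layers with arbitrary signatures.
        super().__init__()


# Hypothetical usage mirroring the diff: a non-last rank gets a placeholder
# norm that simply passes tensors through.
norm = PPMissingLayer()
x = torch.randn(2, 8)
assert torch.equal(norm(x), x)  # Identity forward: input passes through.
assert isinstance(norm, nn.Module)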
