[Distributed][PP] only create embedding & lm head when necessary (vllm-project#6455)

original title: [Distributed][Model] Rank-based Component Creation for Pipeline Parallelism Memory Optimization
wushidonguc authored and fialhocoelho committed Jul 19, 2024
1 parent 7668b56 commit b63d466
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions vllm/model_executor/models/llama.py
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class LlamaMLP(nn.Module):
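
Note (illustration, not part of this commit): PPMissingLayer is imported from vllm/model_executor/models/utils.py, whose definition is not shown in this diff. A minimal sketch of such a placeholder, assuming it behaves as a parameter-free identity module that can stand in for any layer a rank does not build:

    import torch

    class PPMissingLayer(torch.nn.Identity):
        """Placeholder for a layer that lives on another pipeline rank.

        Illustrative sketch only; see vllm/model_executor/models/utils.py
        for the actual definition. It holds no parameters (so it costs no
        GPU memory) and returns its input unchanged if ever called.
        """

        def __init__(self, *args, **kwargs):
            # Accept and discard any constructor arguments so the same
            # call sites work whether or not the real layer is built.
            super().__init__()
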
@@ -257,17 +257,24 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
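
Note (illustration, not part of this commit): because a rank now skips constructing modules it does not own, checkpoint loading must also skip the corresponding weights, which is what the already-imported is_pp_missing_parameter helper is for. A simplified, hypothetical sketch of that guard (the real load_weights in llama.py additionally remaps fused parameters and handles quantization and LoRA):

    from vllm.model_executor.models.utils import is_pp_missing_parameter

    def load_weights_sketch(model, weights):
        # 'weights' is an iterable of (name, tensor) pairs from the checkpoint.
        params_dict = dict(model.named_parameters())
        for name, loaded_weight in weights:
            # E.g. "model.embed_tokens.weight" is "missing" on ranks where the
            # module was replaced by PPMissingLayer, so it is simply skipped.
            if is_pp_missing_parameter(name, model):
                continue
            params_dict[name].data.copy_(loaded_weight)
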
@@ -360,26 +367,30 @@ def __init__(
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-            quant_config=quant_config,
-        )
-        if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
-        self.sampler = Sampler()
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = PPMissingLayer()
 
     def forward(
         self,
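
Note (illustration, not part of this commit): the following standalone snippet mirrors the conditions introduced above to show which of these modules each pipeline rank ends up building, e.g. with 4 stages:

    def modules_on_rank(rank: int, world_size: int, tie_word_embeddings: bool):
        # Mirrors get_pp_group().is_first_rank / .is_last_rank from the diff.
        first, last = rank == 0, rank == world_size - 1
        built = []
        if first or (tie_word_embeddings and last):
            built.append("embed_tokens")
        if last:
            built.extend(["norm", "lm_head"])
        return built

    if __name__ == "__main__":
        for r in range(4):
            print(r, modules_on_rank(r, 4, tie_word_embeddings=False))
        # 0 ['embed_tokens'] / 1 [] / 2 [] / 3 ['norm', 'lm_head']

With tie_word_embeddings=True, the last rank also builds embed_tokens so that lm_head.weight can be tied to it, as in the hunk above.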
