From 70a48d2b5bcc64630aae565983942da05d7839a5 Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Mon, 15 Jul 2024 22:14:52 +0000
Subject: [PATCH 1/7] Optimize memory for pipeline parallelism by creating only necessary layers per rank

---
 vllm/model_executor/models/llama.py | 59 +++++++++++++++++------------
 1 file changed, 34 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f03e34b9e7c92..cfe4cca8cfa3b 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -257,17 +257,21 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-        )
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                self.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
                                       cache_config=cache_config,
                                       quant_config=quant_config))
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = None

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -360,26 +364,31 @@ def __init__(
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-            quant_config=quant_config,
-        )
-        if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings:
+                self.lm_head.weight = self.model.embed_tokens.weight

-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
-        self.sampler = Sampler()
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+            self.sampler = Sampler()
+        else:
+            self.lm_head = None
+            self.logits_processor = None
+            self.sampler = None

     def forward(
         self,
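The first patch gates module construction on the pipeline-parallel rank: only the first rank builds the token embedding, every rank builds just its own slice of decoder layers through the make_layers helper, and only the last rank builds the final RMSNorm, LM head, logits processor, and sampler, so no rank allocates weights it never uses. The toy sketch below shows the same construction pattern in isolation; it is plain PyTorch rather than vLLM's parallel layers, and partition_layers, rank, and pp_size are illustrative assumptions, not vLLM helpers.

    # Toy sketch of rank-gated module construction for pipeline parallelism.
    # Illustrative only: partition_layers, rank and pp_size are assumptions,
    # not vLLM's helpers.
    import torch.nn as nn

    def partition_layers(num_layers: int, rank: int, pp_size: int) -> range:
        # Split layer indices evenly across ranks; early ranks absorb the remainder.
        per_rank, extra = divmod(num_layers, pp_size)
        start = rank * per_rank + min(rank, extra)
        return range(start, start + per_rank + (1 if rank < extra else 0))

    class ToyPipelineStage(nn.Module):

        def __init__(self, rank: int, pp_size: int, num_layers: int = 8,
                     hidden: int = 256, vocab: int = 1000):
            super().__init__()
            is_first = rank == 0
            is_last = rank == pp_size - 1
            # Only the first stage owns the embedding.
            self.embed_tokens = nn.Embedding(vocab, hidden) if is_first else None
            # Each stage owns only its contiguous block of decoder layers.
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(hidden, 8, batch_first=True)
                for _ in partition_layers(num_layers, rank, pp_size))
            # Only the last stage owns the final norm and the LM head.
            self.norm = nn.LayerNorm(hidden) if is_last else None
            self.lm_head = nn.Linear(hidden, vocab, bias=False) if is_last else None

For example, ToyPipelineStage(rank=1, pp_size=4) holds only decoder layers 2 and 3 and neither the embedding nor the head, which is exactly the per-rank memory saving the patch is after.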
From 7c9fa7c9c9ec98441aef1f9c62d8bffa8dd25db7 Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 22:23:25 +0000
Subject: [PATCH 2/7] Use PPMissingLayer after rebasing to latest

---
 vllm/model_executor/models/llama.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index cfe4cca8cfa3b..fbe8e5b70e146 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import is_pp_missing_parameter, make_layers, PPMissingLayer


 class LlamaMLP(nn.Module):
@@ -263,6 +263,8 @@ def __init__(
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
             )
+        else:
+            self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda: LlamaDecoderLayer(config=config,
@@ -271,7 +273,7 @@ def __init__(
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
-            self.norm = None
+            self.norm = PPMissingLayer()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -386,9 +388,9 @@ def __init__(
                                                 config.vocab_size, logit_scale)
             self.sampler = Sampler()
         else:
-            self.lm_head = None
-            self.logits_processor = None
-            self.sampler = None
+            self.lm_head = PPMissingLayer()
+            self.logits_processor = PPMissingLayer()
+            self.sampler = PPMissingLayer()

     def forward(
         self,

From 1552bcf43c5b635f432d49781087ff92656d5efb Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 22:35:48 +0000
Subject: [PATCH 3/7] Formatting

---
 vllm/model_executor/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index fbe8e5b70e146..f833e9299fbb6 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -385,7 +385,7 @@ def __init__(

             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                    config.vocab_size, logit_scale)
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()

From 5428cb03aabd7bd7ec61c6d6b6de0d54dff87a78 Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 22:41:48 +0000
Subject: [PATCH 4/7] More formatting

---
 vllm/model_executor/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f833e9299fbb6..c5abcc6e09c68 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -50,7 +50,7 @@
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers, PPMissingLayer
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):
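Patches 2 through 4 replace the bare None placeholders with PPMissingLayer from .utils (patches 3 and 4 only fix line wrapping and import order), so a rank that does not own a module still exposes a real nn.Module attribute and the surrounding attribute access and weight-loading code stays uniform. As a rough idea of what such a placeholder can look like, the stand-in below is illustrative only; vLLM's actual PPMissingLayer lives in vllm/model_executor/models/utils.py and may differ in detail.

    # Minimal stand-in for a pipeline-parallel "missing layer" placeholder.
    # Illustrative only; not vLLM's actual PPMissingLayer.
    import torch.nn as nn

    class MissingLayerPlaceholder(nn.Identity):
        """Occupies the attribute slot on ranks that do not own the real module.

        As an nn.Module it registers cleanly in the module tree (so parameter
        traversal and weight loading can recognise and skip it), and as an
        Identity it simply returns its input if it is ever called.
        """

        def __init__(self, *args, **kwargs):
            # Accept and ignore whatever arguments the real layer would take.
            super().__init__()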
From 9173c266e513d4e1a42f151a328497acb9c4dcc9 Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 23:28:49 +0000
Subject: [PATCH 5/7] Take care of tie word embedding

---
 vllm/model_executor/models/llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c5abcc6e09c68..201c246b1eaf7 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -257,7 +257,8 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        if get_pp_group().is_first_rank:
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings and
+                                            get_pp_group().is_last_rank):
             self.embed_tokens = VocabParallelEmbedding(
                 self.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
             )

From d060127d56d0850d36aee92add004f738d9f0f2c Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 23:31:03 +0000
Subject: [PATCH 6/7] Do not change LogitProcessor and Sampler

---
 vllm/model_executor/models/llama.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 201c246b1eaf7..cd6bca6b37bf7 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -390,8 +390,6 @@ def __init__(
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()
-            self.logits_processor = PPMissingLayer()
-            self.sampler = PPMissingLayer()

     def forward(
         self,

From 3593de8f7a88100df63cc3e8f2df0baf0ee4c9c7 Mon Sep 17 00:00:00 2001
From: wushidonguc
Date: Tue, 16 Jul 2024 23:48:45 +0000
Subject: [PATCH 7/7] Formatting

---
 vllm/model_executor/models/llama.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index cd6bca6b37bf7..4c434e54cf743 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -257,8 +257,8 @@ def __init__(
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        if get_pp_group().is_first_rank or (config.tie_word_embeddings and
-                                            get_pp_group().is_last_rank):
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
             self.embed_tokens = VocabParallelEmbedding(
                 self.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
@@ -386,7 +386,8 @@ def __init__(

             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size, logit_scale)
+                                                    config.vocab_size,
+                                                    logit_scale)
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()
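Patch 5 widens the embedding condition so the last rank also constructs embed_tokens when config.tie_word_embeddings is set: the LM head that lives on the last rank ties its weight to the embedding matrix, so that weight has to exist there even though the last rank never runs the embedding forward pass. Patch 6 drops the placeholder assignments for the logits processor and sampler again, and patch 7 is a pure line-wrapping fix. The standalone sketch below (plain PyTorch with illustrative flags, not vLLM's parallel layers) shows the tying that motivates the extra condition.

    # Why the last rank needs the embedding when word embeddings are tied.
    # Illustrative only: the flags and plain nn modules are assumptions,
    # not vLLM's TP/PP-aware layers.
    import torch.nn as nn

    hidden, vocab = 256, 1000
    is_first_rank, is_last_rank, tie_word_embeddings = False, True, True

    # Mirror of the patched condition: build the embedding on the first rank,
    # and additionally on the last rank whenever its weight must be shared.
    if is_first_rank or (tie_word_embeddings and is_last_rank):
        embed_tokens = nn.Embedding(vocab, hidden)

    if is_last_rank:
        lm_head = nn.Linear(hidden, vocab, bias=False)
        if tie_word_embeddings:
            # Tying makes the output projection share the embedding matrix,
            # so no second vocab-sized weight is allocated on this rank.
            lm_head.weight = embed_tokens.weight

After the assignment, lm_head.weight and embed_tokens.weight refer to the same (vocab, hidden) parameter, so tying costs the last rank no extra vocabulary-sized allocation.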