[Distributed][Model] Rank-based Component Creation for Pipeline Parallelism Memory Optimization #6455

Merged
merged 7 commits into vllm-project:main
Jul 17, 2024

Conversation

wushidonguc
Contributor

Prior to this change, the Llama model initialization in Pipeline Parallel scenarios would create all components and layers for every rank, leading to unnecessary memory overhead. This pull request optimizes memory usage by creating certain components only on the relevant ranks.

The primary changes include:

  1. Creating the embedding layer on the first rank only.
  2. Creating the norm layer on the last rank only.
  3. Creating the lm head, logits processor, and sampler components on the last rank only.

By selectively creating components on the relevant ranks, this optimization reduces the overall memory footprint. Testing with --pipeline-parallel-size 8 showed up to 25% per-rank memory savings for the Llama-3-70B model. This memory optimization can potentially enable higher throughput by allowing more memory to be utilized for serving.

This optimization enables serving larger models and running inference in resource-constrained environments by leveraging Pipeline Parallelism more efficiently.
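
For illustration, a minimal sketch of the rank-gated construction described above (not the exact diff; vLLM import paths and constructor arguments are assumed and simplified here):

    # Illustrative sketch only: create each component solely on the rank that needs it.
    import torch.nn as nn

    from vllm.distributed import get_pp_group
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)


    class LlamaModelSketch(nn.Module):

        def __init__(self, config):
            super().__init__()
            pp_group = get_pp_group()

            # 1. Embedding layer only on the first pipeline rank.
            if pp_group.is_first_rank:
                self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                           config.hidden_size)
            else:
                self.embed_tokens = None

            # ... only this rank's slice of the decoder layers would be created here ...

            # 2./3. Final norm and LM head only on the last pipeline rank.
            if pp_group.is_last_rank:
                self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
            else:
                self.norm = None
                self.lm_head = None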

Testing:

  • Benchmarks have been conducted to measure the memory savings and performance impact of this optimization.

Please review the changes and provide feedback or suggestions for further improvements.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only trigger fastcheck CI, which consists of a small and essential subset of tests to quickly catch errors, with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run).

A full CI run is still required to merge this PR, so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger a full CI run as well.

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@sfc-gh-hazhang
Contributor

looks nice. I'll do a test on H100 setups and confirm improvement.

Comment on lines 280 to 283
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = None
Member

I suppose the norm should take only a small fraction of GPU memory. Is it worth the effort?

@youkaichao
Member

youkaichao commented Jul 16, 2024

thanks for the contribution! the overall idea looks good to me.

we actually use the minimum of the number of blocks across all processes to determine the final number of blocks, which means this change is only effective if it reduces the lower bound of model memory usage across ranks.

In this case, you:

  • removed lm head for the first pp rank
  • removed both word embedding and lm head for the rest ranks
  • removed word embedding for the last rank

Overall, the minimum memory usage of the model is reduced by about the size of a word embedding or an lm head layer (typically the same size).

It would be better if you could move this logic into common layers, so that the rest of the models can directly benefit from it. I added PPMissingLayer in #6406 recently, which might help. My goal is to reduce the code changes needed to support PP, to make sure the code is not too intrusive and people find it easy to integrate.

One niche case is that you need to take care of models that use tied embeddings.
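
For background on the point about the block minimum, a rough illustration (not vLLM's actual code) of why per-rank savings only matter if they raise the group-wide minimum:

    # Rough illustration only: the usable number of KV-cache blocks is the minimum
    # across all ranks, so freeing memory on a single rank does not help unless the
    # lowest-memory rank also benefits.
    import torch
    import torch.distributed as dist

    def determine_num_gpu_blocks(free_gpu_memory_bytes: int,
                                 block_size_bytes: int) -> int:
        local_blocks = free_gpu_memory_bytes // block_size_bytes
        t = torch.tensor([local_blocks], dtype=torch.int64)
        # Every rank contributes its local count; the group-wide minimum wins.
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        return int(t.item())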

@andoorve
Collaborator

Thanks for your change! This should bring some improvement. I'll let @sfc-gh-hazhang report back.

@wushidonguc
Contributor Author

wushidonguc commented Jul 16, 2024

@youkaichao Thanks for the feedback and suggestions!

Instead of modifying VocabParallelEmbedding and ParallelLMHead, I propose keeping the pipeline parallelism logic within the model-specific classes. We can introduce helper methods like get_rank_dependent_embedding and get_rank_dependent_lm_head to encapsulate the rank-aware logic. These helpers would return the appropriate instances based on the current rank.

This way, VocabParallelEmbedding and ParallelLMHead remain generic and reusable across models. The pipeline parallelism logic stays in the model-specific classes, making it easier to maintain and extend.

Please let me know if this approach works for you.
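
To make the proposal concrete, a hypothetical sketch of one such helper (the name comes from the comment above; this helper was not part of the merged change):

    # Hypothetical sketch of the proposed model-level helper.
    from typing import Optional

    from vllm.distributed import get_pp_group
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    def get_rank_dependent_lm_head(config) -> Optional[ParallelLMHead]:
        """Create the LM head only on the last pipeline rank; other ranks get None."""
        if get_pp_group().is_last_rank:
            return ParallelLMHead(config.vocab_size, config.hidden_size)
        return None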

@youkaichao
Member

The pipeline parallelism logic stays in the model-specific classes, making it easier to maintain and extend.

this idea looks good to me. I'd like to make the code change as small as possible, and it would be better if we can make the code intuitive.

@wushidonguc
Contributor Author

@youkaichao I've rebased the branch on the latest main and integrated the PPMissingLayer component to handle weight skipping. No helper functions are added at this stage, but they can be introduced in a separate pull request if deemed necessary.

Please review the changes and provide any feedback or concerns you may have. I tested the integration of PPMissingLayer locally, and it seems to be working as expected.
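
Roughly, the PPMissingLayer-based version replaces the bare None placeholders with a parameter-free placeholder module, so weight loading can skip the tensors that do not exist on a given rank. A hedged sketch (the import path for PPMissingLayer is assumed from #6406, not verified here):

    # Sketch of PPMissingLayer-based placeholders for components absent on this rank.
    import torch.nn as nn

    from vllm.distributed import get_pp_group
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
    from vllm.model_executor.models.utils import PPMissingLayer

    class LlamaModelSketch(nn.Module):

        def __init__(self, config):
            super().__init__()
            if get_pp_group().is_first_rank:
                self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                           config.hidden_size)
            else:
                # No parameters, so the weight loader simply skips these tensors.
                self.embed_tokens = PPMissingLayer()
            # ... this rank's decoder layers ...
            if get_pp_group().is_last_rank:
                self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            else:
                self.norm = PPMissingLayer()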

Comment on lines +373 to +383
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
Member

I think you should only put the if-else around this layer; LogitsProcessor and Sampler don't have parameters. Let's make the change as small as possible.

            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        if get_pp_group().is_first_rank:
Member

How about if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank)? We need to take care of tied word embeddings as well.

Contributor Author

Done

Comment on lines 388 to 390
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()
Member

Some niche details: I think it would be better to keep these two lines untouched.

Your code should just add an if-else around:

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

Contributor Author

I understand your perspective about keeping those lines untouched. However, placing LogitsProcessor and Sampler within the if-else statement avoids unnecessary object creation for ranks other than the last rank.

@youkaichao
Member

/ready

@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jul 17, 2024
@youkaichao youkaichao left a comment

The test failure in https://buildkite.com/vllm/ci-aws/builds/5017#0190be33-93a3-48a3-a443-90c80d0b63df can be ignored; I just disabled those tests because they are flaky. We need to wait for @andoorve to fix them.

As long as the other tests pass, we can merge this PR.

thanks for the contribution!

@youkaichao youkaichao merged commit 1d094fd into vllm-project:main Jul 17, 2024
80 of 85 checks passed
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 19, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024