Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize zero3 fetch params using all_reduce #5420

Merged
merged 5 commits into from
May 20, 2024

Conversation

deepcharm
Copy link
Contributor

  • Use all_reduce instead of all_gather to fetch module parameters. This improves performance by reducing the overhead of concatenation and slicing, which are no longer required.
  • Instead, all tensors views are created prior to the collective (all_reduce), so upon its completion only the parameter status is updated.
  • The behavior is enabled via a new boolean flag under the section "zero_optimization": { "stage3_use_all_reduce_for_fetch_params": true }
  • By default the optimization is not enabled.

* Use all_reduce instead of all_gather to fetch module parameters.
  This reduces overhead of concatenation and slicing, which are no
  longer required.
* All tensors views are created prior to the collective (all_reduce),
  so upon its completion only the parameter status is updated.
* The behavior is enabled via a new boolean flag under the section
  "zero_optimization": {
      "stage3_use_all_reduce_for_fetch_params": true
  }
* By default the optimization is not enabled.
@tjruwase tjruwase requested review from GuanhuaWang and tohtana and removed request for mrwyattii April 16, 2024 14:48
@tjruwase
Copy link
Contributor

@deepcharm, thanks for this interesting approach. Can you share some observed performance gains?

@deepcharm
Copy link
Contributor Author

@deepcharm, thanks for this interesting approach. Can you share some observed performance gains?

@tjruwase We have observed around 9% performance gain on HPU in BERT workloads.

@GuanhuaWang
Copy link
Member

GuanhuaWang commented Apr 16, 2024

Hi @deepcharm

Thx for the PR. Just curious why allreduce could be faster than allgather? allreduce basically is doing reduce-scatter + all-gather. Could we just make allgather as coalesced version to remove the overhead of concatenation and slicing?

@deepcharm
Copy link
Contributor Author

Hi @deepcharm

Thx for the PR. Just curious why allreduce could be faster than allgather? allreduce basically is doing reduce-scatter + all-gather. Could we just make allgather as coalesced version to remove the overhead of concatenation and slicing?

Hi @GuanhuaWang, you're right the proposed approach indeed adds some communication overhead. The main idea is to re-arrange the layout of the sharded pieces in the flat buffer to achieve overall perf boost.

Hopefully, the attached slides below help clarify the benefits (less Host side overhead, smaller memory peak, etc).
Please let me know if that answers your questions.

1) Current Approach

Current_Approach

2) Proposed Optimization

Proposal

3) Comparison

Comparison

@GuanhuaWang
Copy link
Member

GuanhuaWang commented Apr 23, 2024

Hi @deepcharm
Thx for the PR. Just curious why allreduce could be faster than allgather? allreduce basically is doing reduce-scatter + all-gather. Could we just make allgather as coalesced version to remove the overhead of concatenation and slicing?

Hi @GuanhuaWang, you're right the proposed approach indeed adds some communication overhead. The main idea is to re-arrange the layout of the sharded pieces in the flat buffer to achieve overall perf boost.

Hopefully, the attached slides below help clarify the benefits (less Host side overhead, smaller memory peak, etc). Please let me know if that answers your questions.

1) Current Approach

Current_Approach

2) Proposed Optimization

Proposal

3) Comparison

Comparison

Hi @deepcharm , these slides are cool and make sense to me. But as 2) Proposed Optimization, it showed removing unnecessay data concat&copy by avoiding params interleaving of allgather (Not allreduce). Allreduce is what confuses me, we don't do any sum/avg operation on collected weights right?

@tjruwase
Copy link
Contributor

image

@deepcharm, I was not aware that narrow, cat, copy operations on device tensors incurred high CPU overhead. I will like to learn more. Can you share the reason? How did you discover this? Can you share some repro/test code for this? Thanks!

@deepcharm
Copy link
Contributor Author

image @deepcharm, I was not aware that narrow, cat, copy operations on device tensors incurred high CPU overhead. I will like to learn more. Can you share the reason? How did you discover this? Can you share some repro/test code for this? Thanks!

@tjruwase, we've seen this phenomenon in large models where looping over many params causes significant CPU overhead.
Possibly this issue is more specific for accelerators such as HPU.
We will create a repro script and share with you.

@tjruwase
Copy link
Contributor

tjruwase commented May 7, 2024

@tjruwase, we've seen this phenomenon in large models where looping over many params causes significant CPU overhead.
Possibly this issue is more specific for accelerators such as HPU.
We will create a repro script and share with you.

@deepcharm, very interesting, thanks for the explanation. I look forward to learning more from the repro script. I think it might be a great documentation for performance debugging of zero3 on accelerators.

@tjruwase tjruwase added this pull request to the merge queue May 7, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to no response for status checks May 7, 2024
@deepcharm
Copy link
Contributor Author

Hi @tjruwase, for some reason the PR has been removed from the merge-queue. Can you please re-add it? Thanks

@tjruwase tjruwase added this pull request to the merge queue May 13, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks May 13, 2024
@tjruwase tjruwase added this pull request to the merge queue May 20, 2024
Merged via the queue into microsoft:master with commit 49df8d8 May 20, 2024
14 checks passed
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this pull request Jun 10, 2024
* Use all_reduce instead of all_gather to fetch module parameters. This
improves performance by reducing the overhead of concatenation and
slicing, which are no longer required.
* Instead, all tensors views are created prior to the collective
(all_reduce), so upon its completion only the parameter status is
updated.
* The behavior is enabled via a new boolean flag under the section
"zero_optimization": { "stage3_use_all_reduce_for_fetch_params": true }
* By default the optimization is not enabled.

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants