Optimize zero3 fetch params using all_reduce #5420
Conversation
deepcharm commented on Apr 16, 2024
- Use all_reduce instead of all_gather to fetch module parameters. This improves performance by removing the overhead of concatenation and slicing, which are no longer required.
- Instead, all tensor views are created prior to the collective (all_reduce), so upon its completion only the parameter status is updated.
- The behavior is enabled via a new boolean flag under the "zero_optimization" section: { "stage3_use_all_reduce_for_fetch_params": true } (see the config sketch after this list).
- By default the optimization is not enabled.
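For reference, a minimal sketch of a config enabling the flag introduced by this PR. Only `stage3_use_all_reduce_for_fetch_params` comes from the PR; the surrounding fields are illustrative placeholders, not a recommended setup.

```python
# Illustrative ZeRO-3 config dict; only the new flag is from this PR.
ds_config = {
    "train_batch_size": 8,          # placeholder value
    "zero_optimization": {
        "stage": 3,
        "stage3_use_all_reduce_for_fetch_params": True,  # new flag (off by default)
    },
}
# Typically passed to deepspeed.initialize(..., config=ds_config).
```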
@deepcharm, thanks for this interesting approach. Can you share some observed performance gains?
@tjruwase We have observed around a 9% performance gain on HPU in BERT workloads.
Hi @deepcharm, thanks for the PR. Just curious why allreduce could be faster than allgather? Allreduce is basically reduce-scatter + all-gather. Could we instead make the allgather coalesced to remove the overhead of concatenation and slicing?
Hi @GuanhuaWang, you're right, the proposed approach does add some communication overhead. The main idea is to re-arrange the layout of the sharded pieces in the flat buffer to achieve an overall perf boost. Hopefully the attached slides help clarify the benefits (less host-side overhead, smaller memory peak, etc.).

[Slides: 1) Current Approach, 2) Proposed Optimization, 3) Comparison]
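To make the layout trick concrete, here is a minimal conceptual sketch (not the actual DeepSpeed implementation; the function name and sizing are illustrative) of fetching a full parameter with all_reduce instead of all_gather:

```python
# Conceptual sketch: each rank writes its shard into its final slot of a
# zero-filled flat buffer, then a SUM all_reduce materializes the full
# parameter in place -- no post-collective concat or per-shard slicing.
import torch
import torch.distributed as dist

def fetch_param_all_reduce(shard: torch.Tensor, full_numel: int,
                           rank: int, world_size: int) -> torch.Tensor:
    flat = torch.zeros(full_numel, dtype=shard.dtype, device=shard.device)
    shard_size = full_numel // world_size
    # Only this rank's slot is non-zero before the collective.
    flat.narrow(0, rank * shard_size, shard.numel()).copy_(shard)
    # Summing across ranks fills every slot exactly once.
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    return flat
```

Because every slot is owned by exactly one rank and all others contribute zeros, the sum is just a placement operation; views into `flat` can therefore be created before the collective, and only the parameter status needs updating afterwards.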
Hi @deepcharm, these slides are cool and make sense to me. But in 2) Proposed Optimization, the gain shown comes from removing the unnecessary data concat by avoiding the param interleaving of allgather (not allreduce). Allreduce is what confuses me: we don't do any sum/avg operation on the collected weights, right?
@deepcharm, I was not aware that narrow, cat, and copy operations on device tensors incurred high CPU overhead. I would like to learn more. Can you share the reason? How did you discover this? Can you share some repro/test code for this? Thanks!
@tjruwase, we've seen this phenomenon in large models where looping over many params causes significant CPU overhead.
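No repro was posted in the thread; purely as an illustration, a hypothetical micro-benchmark along these lines (all names and sizes invented here) can surface how per-parameter cat launches accumulate host-side cost:

```python
# Hypothetical micro-benchmark (not the repro referenced above): times the
# host-side loop of one torch.cat per parameter over many small shards.
import time
import torch

def time_cat_loop(num_params=1000, shard_numel=4096, world_size=8,
                  device="cpu"):
    # One list of world_size shards per parameter, as after an all_gather.
    shards = [[torch.randn(shard_numel, device=device)
               for _ in range(world_size)] for _ in range(num_params)]
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    full = [torch.cat(per_param) for per_param in shards]  # one cat per param
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"{num_params} cats took {time.perf_counter() - start:.3f}s")

time_cat_loop(device="cuda" if torch.cuda.is_available() else "cpu")
```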
@deepcharm, very interesting, thanks for the explanation. I look forward to learning more from the repro script. I think it could make great documentation for performance debugging of zero3 on accelerators.
Hi @tjruwase, for some reason the PR has been removed from the merge-queue. Can you please re-add it? Thanks!
Merged. The commit message repeats the PR description above, with:
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>