
[Enhancement] Support gather operation in NCCL backend #1061

Closed

Conversation

sh0622-kim (Contributor)

Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

Motivation

I wanted to help with the work in #916.

Modification

Supports the gather operation for the NCCL backend.
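
For context, a minimal usage sketch (it assumes `mmengine.dist.gather` follows the calling convention of the library's other collectives such as `all_gather`; the script name is hypothetical):

```python
# Launch with e.g.: torchrun --nproc_per_node=2 gather_demo.py
import torch
from mmengine import dist

dist.init_dist(launcher='pytorch', backend='nccl')

data = torch.tensor([dist.get_rank()], device='cuda')
# Gather every rank's tensor onto rank 0; other ranks get an empty list.
gathered = dist.gather(data, dst=0)
if dist.get_rank() == 0:
    print(gathered)  # e.g. [tensor([0]), tensor([1])]
```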

Checklist

  1. Pre-commit or other linting tools are used to fix potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification could affect downstream projects, this PR should be tested with them, e.g. MMDet or MMCls.
  4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

@HAOCHENYE (Collaborator) left a comment:

Hi, thanks for your contribution! We should update the unit tests here to verify that the modification works as expected.

Besides, PyTorch has supported gather in the NCCL backend since version 1.11, and we should take that into account.

Add a PyTorch version condition, e.g.:

```python
# digit_version/TORCH_VERSION are mmengine.utils helpers (assumed here).
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
    torch_dist.gather(data, gather_list, dst, group)
else:
    if get_rank(group) == dst:
        gather_list = torch.cuda.comm.gather(data, dst, group)
```
@zhouzaida (Collaborator) commented on Apr 10, 2023:

Hi, torch.cuda.comm.gather only supports the single-node case. Can we use all_gather to implement it as a workaround?
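
To illustrate the distinction (an illustrative snippet, not from the PR): torch.cuda.comm.gather collects tensors that already live on the GPUs of one machine and involves no process group, so it cannot serve a multi-node gather.

```python
import torch

# Single-machine gather: concatenate per-GPU tensors onto device 0.
# No process group is involved, hence no cross-node communication.
tensors = [torch.ones(2, device=f'cuda:{i}')
           for i in range(torch.cuda.device_count())]
gathered = torch.cuda.comm.gather(tensors, dim=0, destination=0)
print(gathered.device)  # cuda:0
```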

@HAOCHENYE (Collaborator):

Hi @sh0622-kim, you can use all_gather to replace torch.cuda.comm.gather when the PyTorch version is earlier than 1.11.0.

@HAOCHENYE added this to the 0.7.4 milestone on Apr 23, 2023.
```python
all_gather_list = all_gather(data, group)
if get_rank(group) == dst:
    gather_list = all_gather_list
else:
    gather_list = []
```
A collaborator left a comment:

all_gather should be called at all ranks, otherwise the program will block. We should only return the gathered list at the main rank and return an empty list at the other ranks.
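
Putting the two review comments together, a minimal sketch of the version-gated fallback (assumptions: `all_gather`, `get_rank`, and `get_world_size` are the mmengine.dist helpers, and `digit_version`/`TORCH_VERSION` come from mmengine.utils; this is not the PR's exact diff):

```python
import torch
import torch.distributed as torch_dist

from mmengine.dist import all_gather, get_rank, get_world_size
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


def gather(data, dst=0, group=None):
    """Gather ``data`` from every rank onto rank ``dst``."""
    if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
        # NCCL supports gather natively from PyTorch 1.11 onward.
        if get_rank(group) == dst:
            gather_list = [torch.empty_like(data)
                           for _ in range(get_world_size(group))]
        else:
            gather_list = None
        torch_dist.gather(data, gather_list, dst, group)
        return gather_list if gather_list is not None else []
    # Older PyTorch: all_gather must run on *every* rank, otherwise the
    # collective blocks; only ``dst`` keeps the gathered result.
    gather_list = all_gather(data, group)
    return gather_list if get_rank(group) == dst else []
```

Returning an empty list at non-destination ranks keeps the call collective-safe while matching the behavior the reviewer describes.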

codecov bot commented Sep 26, 2024

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Please upload a report for BASE (main@8bf1eca).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| mmengine/dist/dist.py | 0.00% | 6 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1061   +/-   ##
=======================================
  Coverage        ?   77.88%           
=======================================
  Files           ?      139           
  Lines           ?    11301           
  Branches        ?     2281           
=======================================
  Hits            ?     8802           
  Misses          ?     2104           
  Partials        ?      395           
| Flag | Coverage Δ |
| --- | --- |
| unittests | 77.88% <0.00%> (?) |

Flags with carried forward coverage won't be shown.


@sh0622-kim closed this on Sep 26, 2024.