Skip to content

Commit

Permalink
Support batch_isend_irecv operation in torch plugin by adding group s…
Browse files Browse the repository at this point in the history
…emantics (#28)
  • Loading branch information
yzhang35 authored Feb 7, 2025
1 parent 5ae71db commit 0869dcf
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions plugin/torch/src/backend_flagcx.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "backend_flagcx.hpp"
#include <iostream>
#include <c10/core/DeviceGuard.h>

namespace c10d
{
Expand Down Expand Up @@ -254,14 +255,19 @@ namespace c10d

void BackendFlagcx::groupStart()
{
// flagcxGroupStart();
// ++flagcxActiveGroupCounter_;
#if defined(USE_NVIDIA_ADAPTOR) || defined(USE_ILUVATAR_COREX_ADAPTOR)
initComm(c10::impl::getDeviceGuardImpl(at::DeviceType::CUDA)->getDevice());
#elif defined(USE_CAMBRICON_ADAPTOR)
initComm(c10::impl::getDeviceGuardImpl(at::DeviceType::PrivateUse1)->getDevice());
#endif
flagcxGroupStart();
++flagcxActiveGroupCounter_;
}

void BackendFlagcx::groupEnd()
{
// flagcxGroupEnd();
// --flagcxActiveGroupCounter_;
flagcxGroupEnd();
--flagcxActiveGroupCounter_;
}

void BackendFlagcx::startCoalescing()
Expand Down Expand Up @@ -816,4 +822,4 @@ namespace c10d
m.def("createBackendFlagcx", &BackendFlagcx::createBackendFlagcx);
}

} // namespace c10d
} // namespace c10d

0 comments on commit 0869dcf

Please sign in to comment.