Skip to content

Commit

Permalink
AllGather Executor Support in NCCL Interface (#393)
Browse files Browse the repository at this point in the history
Co-authored-by: Ziyue Yang <[email protected]>
Co-authored-by: Changho Hwang <[email protected]>
Co-authored-by: Binyang Li <[email protected]>
  • Loading branch information
4 people authored Nov 28, 2024
1 parent 593478e commit d9c297b
Showing 1 changed file with 104 additions and 34 deletions.
138 changes: 104 additions & 34 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ struct ncclComm {
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan;
allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary;
size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -279,6 +280,46 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
return ncclSuccess;
}

static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

// Declarating variables
size_t recvBytes;
CUdeviceptr recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
channelKey recvKey{(void*)recvBasePtr, recvBytes};
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

auto it = comm->channelOutInfos.find(recvKey);
if (it == comm->channelOutInfos.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

smChannels = it->second.smChannelDeviceHandles.get();
if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
} else {
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
}

return ncclSuccess;
}

NCCL_API ncclResult_t ncclGetVersion(int* version) {
if (version == nullptr) return ncclInvalidArgument;
*version = MSCCLPP_VERSION;
Expand Down Expand Up @@ -355,15 +396,39 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->allReduceLargeMessageSizeBoundary = 1 << 20;

if (getenv("ALLGATHERPKT_IP_JSON_FILE"))
commPtr->allGatherPacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE")));
if (getenv("ALLGATHERPKT_OP_JSON_FILE"))
commPtr->allGatherPacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE")));
if (getenv("ALLGATHER_IP_JSON_FILE"))
commPtr->allGatherIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE")));
if (getenv("ALLGATHER_OP_JSON_FILE"))
commPtr->allGatherOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE")));
if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY"))
commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY"));
else
commPtr->largeMessageSizeBoundary = 1 << 20;
commPtr->allGatherSmallMessageSizeBoundary = (1 << 10);
if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY"))
commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY"));
else
commPtr->allGatherLargeMessageSizeBoundary = 1 << 20;

if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;
if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary)
return ncclInvalidArgument;
if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary)
return ncclInvalidArgument;

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -483,11 +548,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();

if (bytes < comm->smallMessageSizeBoundary) {
if (bytes < comm->allReduceSmallMessageSizeBoundary) {
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= comm->largeMessageSizeBoundary)
if (bytes <= comm->allReduceLargeMessageSizeBoundary)
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else {
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
Expand Down Expand Up @@ -533,36 +598,41 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

// Declarating variables
size_t recvBytes;
CUdeviceptr recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
channelKey recvKey{(void*)recvBasePtr, recvBytes};
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

auto it = comm->channelOutInfos.find(recvKey);
if (it == comm->channelOutInfos.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary)
return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary)
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan;
else {
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan;
}

smChannels = it->second.smChannelDeviceHandles.get();
if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
} else {
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::FLOAT16,
*plan, stream);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes * nRank,
mscclpp::DataType::FLOAT32, *plan, stream);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes * nRank,
mscclpp::DataType::BFLOAT16, *plan, stream);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::UINT32,
*plan, stream);
break;
default:
return ncclInvalidArgument;
}

return ncclSuccess;
Expand Down

0 comments on commit d9c297b

Please sign in to comment.