From d9c297ba14dd219875dc4e691899c3668c86efbe Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:05:51 -0800 Subject: [PATCH] AllGather Executor Support in NCCL Interface (#393) Co-authored-by: Ziyue Yang Co-authored-by: Changho Hwang Co-authored-by: Binyang Li --- apps/nccl/src/nccl.cu | 138 +++++++++++++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 34 deletions(-) diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 2bd2a442..2b7e9736 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -58,7 +58,7 @@ struct ncclComm { std::vector> smSemaphores; std::shared_ptr executor; std::shared_ptr allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan, - allReduceOPPlan; + allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan; std::unordered_map channelInInfos; std::unordered_map channelOutInfos; @@ -66,7 +66,8 @@ struct ncclComm { std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; - size_t smallMessageSizeBoundary, largeMessageSizeBoundary; + size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary; + size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary; uint32_t numScratchBuff; uint32_t buffFlag; }; @@ -279,6 +280,46 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, return ncclSuccess; } +static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { + size_t bytes = sendcount * ncclTypeSize(datatype); + if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument; + + // Declarating variables + size_t recvBytes; + CUdeviceptr recvBasePtr; + MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff)); + size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; + channelKey recvKey{(void*)recvBasePtr, recvBytes}; + int rank = comm->comm->bootstrap()->getRank(); + int nRank = comm->comm->bootstrap()->getNranks(); + mscclpp::DeviceHandle* smChannels = nullptr; + + auto it = comm->channelOutInfos.find(recvKey); + if (it == comm->channelOutInfos.end()) { + std::vector remoteMemories = setupRemoteMemories( + comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); + std::vector channels = + setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); + std::vector> smChannelDeviceHandles; + std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); + ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; + it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + } + + smChannels = it->second.smChannelDeviceHandles.get(); + if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { + CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, + NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + } else { + CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, + NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + } + + return ncclSuccess; +} + NCCL_API ncclResult_t ncclGetVersion(int* version) { if (version == nullptr) return ncclInvalidArgument; *version = MSCCLPP_VERSION; @@ -355,15 +396,39 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI commPtr->allReduceOPPlan = std::make_shared(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE"))); if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")) - commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")); + commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")); else - commPtr->smallMessageSizeBoundary = 16 * (1 << 10); + commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10); if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")) - commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")); + commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")); + else + commPtr->allReduceLargeMessageSizeBoundary = 1 << 20; + + if (getenv("ALLGATHERPKT_IP_JSON_FILE")) + commPtr->allGatherPacketIPPlan = std::make_shared( + mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE"))); + if (getenv("ALLGATHERPKT_OP_JSON_FILE")) + commPtr->allGatherPacketOPPlan = std::make_shared( + mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE"))); + if (getenv("ALLGATHER_IP_JSON_FILE")) + commPtr->allGatherIPPlan = + std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE"))); + if (getenv("ALLGATHER_OP_JSON_FILE")) + commPtr->allGatherOPPlan = + std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE"))); + if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY")) + commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY")); else - commPtr->largeMessageSizeBoundary = 1 << 20; + commPtr->allGatherSmallMessageSizeBoundary = (1 << 10); + if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY")) + commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY")); + else + commPtr->allGatherLargeMessageSizeBoundary = 1 << 20; - if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument; + if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary) + return ncclInvalidArgument; + if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary) + return ncclInvalidArgument; *comm = commPtr; return ncclSuccess; @@ -483,11 +548,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t size_t bytes = count * ncclTypeSize(datatype); int rank = comm->comm->bootstrap()->getRank(); - if (bytes < comm->smallMessageSizeBoundary) { + if (bytes < comm->allReduceSmallMessageSizeBoundary) { return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); } else { std::shared_ptr plan; - if (bytes <= comm->largeMessageSizeBoundary) + if (bytes <= comm->allReduceLargeMessageSizeBoundary) plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan; else { plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan; @@ -533,36 +598,41 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t size_t bytes = sendcount * ncclTypeSize(datatype); if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument; - // Declarating variables - size_t recvBytes; - CUdeviceptr recvBasePtr; - MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff)); - size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; - channelKey recvKey{(void*)recvBasePtr, recvBytes}; int rank = comm->comm->bootstrap()->getRank(); int nRank = comm->comm->bootstrap()->getNranks(); - mscclpp::DeviceHandle* smChannels = nullptr; - auto it = comm->channelOutInfos.find(recvKey); - if (it == comm->channelOutInfos.end()) { - std::vector remoteMemories = setupRemoteMemories( - comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); - std::vector channels = - setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); - std::vector> smChannelDeviceHandles; - std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), - [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); - ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; - it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary) + return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); + + std::shared_ptr plan; + if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary) + plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan; + else { + plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan; } - smChannels = it->second.smChannelDeviceHandles.get(); - if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { - CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, - NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); - } else { - CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, - NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); + + switch (datatype) { + case ncclFloat16: + comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::FLOAT16, + *plan, stream); + break; + case ncclFloat32: + comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes * nRank, + mscclpp::DataType::FLOAT32, *plan, stream); + break; + case ncclBfloat16: + comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes * nRank, + mscclpp::DataType::BFLOAT16, *plan, stream); + break; + case ncclInt32: + case ncclUint32: + comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::UINT32, + *plan, stream); + break; + default: + return ncclInvalidArgument; } return ncclSuccess;