
Commit

WIP
Binyang2014 committed Dec 2, 2024
1 parent 0faade2 commit 4d29958
Showing 2 changed files with 16 additions and 9 deletions.

apps/nccl/src/nccl.cu (9 changes: 5 additions & 4 deletions)
@@ -400,8 +400,8 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId
   setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
   commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
 
-  if (getenv("COMMUNICATION_COLLECTIVE_DIR")) {
-    std::string collectiveDir = getenv("COMMUNICATION_COLLECTIVE_DIR");
+  if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) {
+    std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR");
     for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
       if (entry.is_regular_file()) {
         std::string filename = entry.path().filename().string();
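
The renamed variable points at a directory of execution-plan JSON files that the communicator walks with std::filesystem when it is set. Below is a minimal standalone sketch of that scan, not the diff's own code: collectPlanPaths is a made-up helper, and restricting the scan to *.json files is an assumption, since the hunk only shows the directory iteration and the filename extraction.

// Standalone sketch (assumptions noted above); C++17 for <filesystem>.
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <vector>

namespace fs = std::filesystem;

// Collect regular *.json files from the directory named by
// MSCCLPP_EXECUTION_PLAN_DIR; returns an empty list when the variable is unset.
std::vector<fs::path> collectPlanPaths() {
  std::vector<fs::path> paths;
  const char* dir = std::getenv("MSCCLPP_EXECUTION_PLAN_DIR");
  if (dir == nullptr) return paths;
  for (const auto& entry : fs::directory_iterator(dir)) {
    if (entry.is_regular_file() && entry.path().extension() == ".json") {
      paths.push_back(entry.path());
    }
  }
  return paths;
}

int main() {
  for (const auto& path : collectPlanPaths()) std::cout << path << "\n";
  return 0;
}
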
@@ -583,9 +583,10 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
 
   std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
   std::shared_ptr<mscclpp::ExecutionPlan> plan;
-  bool inPlace = sendbuff == recvbuff;
+  void* basePtr = (char*)sendbuff - rank * bytes;
+  bool inPlace = basePtr == recvbuff;
   for (const auto& p : plans) {
-    if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
+    if (bytes * nRank >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
       plan = p.plan;
       break;
     }
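
This hunk also tightens the allgather plan selection: the size filter now compares the total gathered payload (bytes * nRank) against a plan's minimum message size, and the in-place test rebases sendbuff by rank * bytes before comparing it with recvbuff, matching the convention that an in-place allgather passes rank's own slot inside recvbuff as sendbuff. A standalone sketch of that pointer check follows; isInPlaceAllGather and the main() driver are illustrative, not part of the diff.

// Standalone sketch of the in-place test; the helper and driver are made up.
#include <cassert>
#include <cstddef>
#include <vector>

// For an allgather, "in place" means sendbuff is rank's slot inside recvbuff,
// i.e. sendbuff == (char*)recvbuff + rank * bytes. Rebasing sendbuff by
// rank * bytes therefore recovers recvbuff exactly when the call is in place.
bool isInPlaceAllGather(const void* sendbuff, const void* recvbuff, int rank, size_t bytes) {
  const void* basePtr = static_cast<const char*>(sendbuff) - rank * bytes;
  return basePtr == recvbuff;
}

int main() {
  constexpr int nRanks = 4;
  constexpr size_t bytes = 1024;  // per-rank contribution
  std::vector<char> recvbuff(nRanks * bytes);

  for (int rank = 0; rank < nRanks; ++rank) {
    const char* inPlaceSend = recvbuff.data() + rank * bytes;  // rank's own slot
    assert(isInPlaceAllGather(inPlaceSend, recvbuff.data(), rank, bytes));
  }

  std::vector<char> separateSend(bytes);  // out-of-place send buffer
  assert(!isInPlaceAllGather(separateSend.data(), recvbuff.data(), /*rank=*/1, bytes));
  return 0;
}
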

src/executor/execution_plan.cc (16 changes: 11 additions & 5 deletions)
@@ -97,7 +97,15 @@ using json = nlohmann::json;
 ExecutionPlan::Impl::Impl(const std::string name, const std::string planPath)
     : name(name), planPath(planPath), isUsingPacket(false) {}
 
-ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) {}
+ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) {
+  std::ifstream file(this->planPath);
+  json obj = json::parse(file);
+  this->name = obj["name"];
+  this->collective = obj["collective"];
+  this->isInPlace = obj["inplace"];
+  this->minMessageSize = obj.value("min_message_size", 0);
+  this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<uint64_t>::max());
+}
 
 std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const {
   auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; };
@@ -186,9 +194,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
                                             size_t constDstOffset) {
   std::ifstream file(this->planPath);
   json obj = json::parse(file);
-  if (this->name.empty()) {
-    this->name = obj["name"];
-  } else if (this->name != obj["name"]) {
+  if (this->name != obj["name"]) {
     throw Error("Plan name does not match", ErrorCode::ExecutorError);
   }
   this->collective = obj["collective"];
@@ -201,7 +207,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
   this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
   this->minMessageSize = obj.value("min_message_size", 0);
   this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<uint64_t>::max());
-  this->isInPlace = obj["in_place"];
+  this->isInPlace = obj["inplace"];
   const auto& gpus = obj["gpus"];
 
   for (const auto& gpu : gpus) {
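
Taken together, the execution_plan.cc changes make the single-argument constructor read the plan metadata (name, collective, the in-place flag, and the optional message-size bounds) up front, while loadExecutionPlan now only verifies that the name in the file still matches and reads the flag from the "inplace" key. A minimal sketch of that metadata read with nlohmann::json is below; PlanMeta, readPlanMeta, plan.json, and the sample field values are illustrative assumptions, not the library's API.

// Standalone sketch; requires nlohmann/json (https://github.com/nlohmann/json).
#include <cstdint>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Hypothetical holder for the fields the constructor reads from the plan file.
struct PlanMeta {
  std::string name;
  std::string collective;
  bool isInPlace = false;
  uint64_t minMessageSize = 0;
  uint64_t maxMessageSize = 0;
};

PlanMeta readPlanMeta(const std::string& planPath) {
  std::ifstream file(planPath);
  json obj = json::parse(file);
  PlanMeta meta;
  meta.name = obj["name"];
  meta.collective = obj["collective"];
  meta.isInPlace = obj["inplace"];
  // json::value() falls back to the default when the key is absent, so plans
  // without explicit bounds match any message size.
  meta.minMessageSize = obj.value("min_message_size", uint64_t{0});
  meta.maxMessageSize = obj.value("max_message_size", std::numeric_limits<uint64_t>::max());
  return meta;
}

int main() {
  // Hypothetical plan file; the keys mirror the ones read in the diff.
  const char* sample = R"({
    "name": "allgather_ring",
    "collective": "allgather",
    "inplace": true,
    "min_message_size": 1048576
  })";
  std::ofstream("plan.json") << sample;

  PlanMeta meta = readPlanMeta("plan.json");
  std::cout << meta.name << ": in-place " << meta.collective << ", from "
            << meta.minMessageSize << " to " << meta.maxMessageSize << " bytes\n";
  return 0;
}

Loading the metadata eagerly is presumably what lets the nccl.cu side fill a plan's size bounds and in-place flag while it indexes files from MSCCLPP_EXECUTION_PLAN_DIR, instead of only after a full loadExecutionPlan.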
