
Add several minor fixes (#32)
MC952-arch authored Feb 10, 2025
1 parent 0869dcf commit 17cac39
Showing 10 changed files with 78 additions and 48 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ FlagCX leverages native collective communications libraries to provide the full
| alltoall ||||||
| group ops |||| ? | ? |

Note that `Homo` and `Hetero` modes refer to communications among homogeneouse and heterogeneous clusters. All supported native collective communication libraries can be referenced through the links below:
Note that `Homo` and `Hetero` modes refer to communications among homogeneous and heterogeneous clusters. All supported native collective communication libraries can be referenced through the links below:

- [NCCL](https://github.com/NVIDIA/nccl), NVIDIA Collective Communications Library.
- [IXCCL](https://www.iluvatar.com/software?fullCode=cpjs-rj-rjz), Iluvatar Corex Collective Communications Library.
9 changes: 8 additions & 1 deletion flagcx/flagcx.cc
@@ -640,8 +640,9 @@ flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff, size_t coun
// intra-cluster reduce
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->reduce(sendbuff, recvbuff, count, datatype, op, comm->homo_inter_rank, comm->homo_comm, stream));

// inter-cluster sendrecv
deviceAdaptor->streamSynchronize(stream);

// inter-cluster sendrecv
if (comm->homo_inter_rank != comm->homo_rank)
{
if (op == flagcxSum) {
@@ -670,6 +671,8 @@ flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff, size_t coun
}
flagcxGroupEnd();

deviceAdaptor->streamSynchronize(stream);

// intra-cluster allreduce
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->allReduce(recvbuff, recvbuff, count, datatype, op, comm->homo_comm, stream));
}
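The allreduce change above separates three phases — intra-cluster reduce, inter-cluster sendrecv between cluster leaders, intra-cluster combine — with an explicit stream synchronization between them. As a hedged illustration only (plain PyTorch rather than FlagCX's API; two equally sized clusters assumed, and the final phase simplified to a broadcast from the leader, whereas the FlagCX path above ends in an intra-cluster allReduce):

import torch
import torch.distributed as dist

def hierarchical_allreduce(tensor, cluster_group, cluster_size):
    # Assumes dist.init_process_group() was called and there are exactly two clusters
    # of cluster_size ranks each; cluster_group is this rank's intra-cluster subgroup.
    rank = dist.get_rank()
    leader = (rank // cluster_size) * cluster_size       # cluster leader, global rank

    # Phase 1: intra-cluster reduce onto the leader.
    dist.reduce(tensor, dst=leader, op=dist.ReduceOp.SUM, group=cluster_group)
    torch.cuda.synchronize()   # mirrors the streamSynchronize inserted before the inter-cluster step

    # Phase 2: leaders exchange partial sums across clusters and accumulate.
    if rank == leader:
        peer = (rank + cluster_size) % dist.get_world_size()
        remote = torch.empty_like(tensor)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, tensor, peer),
            dist.P2POp(dist.irecv, remote, peer),
        ])
        for req in reqs:
            req.wait()
        tensor += remote
    torch.cuda.synchronize()   # mirrors the streamSynchronize after flagcxGroupEnd()

    # Phase 3: propagate the leader's global sum inside each cluster.
    dist.broadcast(tensor, src=leader, group=cluster_group)
    return tensor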
@@ -825,6 +828,8 @@ flagcxResult_t flagcxAllGather(const void *sendbuff, void *recvbuff, size_t send
flagcxGroupEnd();
}

deviceAdaptor->streamSynchronize(stream);

// intra-cluster broadcast
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->broadcast(recvbuff, recvbuff, sendcount * comm->nranks, datatype, comm->homo_inter_rank, comm->homo_comm, stream));
}
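The allgather path has the same shape: once the inter-cluster exchange has filled the receive buffer, the stream is drained and the result is broadcast inside the cluster. A rough sketch under the same two-cluster assumption (illustrative names, not FlagCX calls):

import torch
import torch.distributed as dist

def hierarchical_allgather(send, recv, cluster_group, cluster_size):
    # send: this rank's 1-D chunk; recv: 1-D buffer holding world_size chunks, laid out
    # cluster 0 first, then cluster 1. Two equally sized clusters assumed.
    rank, world = dist.get_rank(), dist.get_world_size()
    cid = rank // cluster_size
    leader = cid * cluster_size
    half = cluster_size * send.numel()

    # Phase 1: allgather inside the cluster into this cluster's half of recv.
    local = recv.narrow(0, cid * half, half)
    dist._all_gather_base(local, send, group=cluster_group)
    torch.cuda.synchronize()

    # Phase 2: leaders swap halves across the two clusters.
    if rank == leader:
        peer = (rank + cluster_size) % world
        other = recv.narrow(0, (1 - cid) * half, half)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, local, peer),
            dist.P2POp(dist.irecv, other, peer),
        ])
        for req in reqs:
            req.wait()
    torch.cuda.synchronize()   # mirrors the streamSynchronize added before the broadcast

    # Phase 3: broadcast the assembled buffer inside the cluster, as in the hunk above.
    dist.broadcast(recv, src=leader, group=cluster_group)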
@@ -962,6 +967,8 @@ flagcxResult_t flagcxAlltoAll(const void *sendbuff, void *recvbuff, size_t count
}
}
flagcxGroupEnd();

deviceAdaptor->streamSynchronize(stream);
}
}
return flagcxSuccess;
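Each added deviceAdaptor->streamSynchronize(stream) in this file sits between device-stream work and a host-driven step that consumes its output. The underlying rule, shown with plain CUDA stream semantics in PyTorch (illustration only):

import torch

s = torch.cuda.Stream()
a = torch.randn(1 << 20, device="cuda")
with torch.cuda.stream(s):
    b = a * 2                    # kernel enqueued on stream s; the host returns immediately
# Work that is not ordered on stream s (a host-driven network send, the default stream)
# must not consume b until s is drained:
s.synchronize()                  # analogue of deviceAdaptor->streamSynchronize(stream)
payload = b.cpu()                # safe: the producing kernel has completed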
53 changes: 25 additions & 28 deletions plugin/torch/example.py
@@ -27,14 +27,14 @@

# Create two groups
ranks = list(range(world_size))
FLAGCX_GROUP1 = dist.new_group(ranks=ranks, backend=f"cpu:gloo,{dev_name}:flagcx")
FLAGCX_GROUP2 = dist.new_group(ranks=ranks, backend=f"cpu:gloo,{dev_name}:flagcx")
FLAGCX_GROUP1 = dist.new_group(ranks=ranks, backend=f"{dev_name}:flagcx")
FLAGCX_GROUP2 = dist.new_group(ranks=ranks, backend=f"{dev_name}:flagcx")
ranks_flagcx = dist.get_process_group_ranks(FLAGCX_GROUP1)
print(f"ranks_flagcx: {ranks_flagcx}")

if torch.cuda.is_available():
# Create tensors
torch.cuda.set_device(rank)
torch.cuda.set_device(rank % 8)
x = torch.rand(world_size).cuda()
y = torch.rand(world_size).cuda()
print(f"rank {rank} initial: x = {x}, y = {y}")
@@ -46,7 +46,6 @@
print(f"rank {rank} after allreduce max with FLAGCX_GROUP1: y = {y}")
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=FLAGCX_GROUP1)
print(f"rank {rank} after allreduce sum with FLAGCX_GROUP1: x = {x}")
dist.barrier(group=FLAGCX_GROUP1)

# Perform send and recv with FLAGCX_GROUP2
for i in range(world_size):
@@ -55,43 +54,44 @@
op_send = dist.P2POp(dist.isend, x, next_rank, group=FLAGCX_GROUP2)
op_recv = dist.P2POp(dist.irecv, y, prev_rank, group=FLAGCX_GROUP2)
op_list = [op_send, op_recv]
if rank % 2 == 1:
op_list.reverse()
reqs = dist.batch_isend_irecv(op_list)
for req in reqs:
req.wait()
print(f"rank {rank} after batch_isend_irecv with FLAGCX_GROUP2: x = {x}, y = {y}")
if rank % 2 == 0:
dist.send(x, next_rank, group=FLAGCX_GROUP2)
dist.send(x, next_rank, group=FLAGCX_GROUP2)
elif rank % 2 == 1:
dist.recv(y, prev_rank, group=FLAGCX_GROUP2)
elif rank % 2 == 1:
dist.recv(y, prev_rank, group=FLAGCX_GROUP2)
print(f"rank {rank} after send/recv with FLAGCX_GROUP2: x = {x}, y = {y}")
dist.send(x, next_rank, group=FLAGCX_GROUP2)
handle = dist.barrier(group=FLAGCX_GROUP2, async_op=True)
handle.wait()
print(f"rank {rank} after send/recv with FLAGCX_GROUP2: x = {x}, y = {y}")

# Perform allgather with FLAGCX_GROUP1
z = torch.rand(1).cuda()
z[0] = rank
print(f"rank {rank} before _all_gather_base with FLAGCX_GROUP1: z = {z}, y = {y}")
dist._all_gather_base(y, z, group=FLAGCX_GROUP1)
print(f"rank {rank} after _all_gather_base with FLAGCX_GROUP1: z = {z}, y = {y}")
z_list = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_gather with FLAGCX_GROUP1: z = {z}, z_list = {z_list}")
dist.all_gather(z_list, z, group=FLAGCX_GROUP1)
print(f"rank {rank} after all_gather with FLAGCX_GROUP1: z = {z}, z_list = {z_list}")
print(z_list)
all_rank_infos = [None] * world_size
cur_rank_info = {'rank': rank, 'device_type': f"cpu:gloo,{dev_name}:flagcx"}
dist.all_gather_object(all_rank_infos, cur_rank_info)
print(f"rank {rank} after all_gather_object with FLAGCX_GROUP1: all_rank_infos = {all_rank_infos}")
dist.barrier(group=FLAGCX_GROUP1)

# Perform broadcast with FLAGCX_GROUP2
x = torch.rand(world_size).cuda()
print(f"rank {rank} before broadcast with FLAGCX_GROUP2: x = {x}")
dist.broadcast(x, 0, group=FLAGCX_GROUP2)
print(f"rank {rank} after broadcast with FLAGCX_GROUP2: x = {x}")
dist.barrier(group=FLAGCX_GROUP2)
# Perform alltoall with FLAGCX_GROUP2
for i in range(world_size):
y[i] = rank
x[i] = 0
list_y = list(torch.chunk(y, world_size, dim=0))
list_z = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.all_to_all(list_z, list_y, group=FLAGCX_GROUP2)
print(f"rank {rank} after all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")

# Perform reducescatter with FLAGCX_GROUP1
z[0] = 0
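The send/recv section above posts the batched P2P ops in reverse order on odd ranks, splits the blocking send/recv by rank parity, and finishes with a non-blocking barrier. A condensed sketch of those patterns (assumes an even world size and an initialized process group):

import torch
import torch.distributed as dist

world, rank = dist.get_world_size(), dist.get_rank()
nxt, prv = (rank + 1) % world, (rank - 1 + world) % world
send_buf = torch.full((4,), float(rank), device="cuda")
recv_buf = torch.empty(4, device="cuda")

# Non-blocking ring exchange; reversing the op list on odd ranks mirrors the example
# above and keeps matched send/recv pairs posted in the same relative order.
ops = [dist.P2POp(dist.isend, send_buf, nxt), dist.P2POp(dist.irecv, recv_buf, prv)]
if rank % 2 == 1:
    ops.reverse()
for req in dist.batch_isend_irecv(ops):
    req.wait()

# Blocking variant: splitting ranks into senders (even) and receivers (odd) keeps every
# blocking send matched with a posted recv, so no cyclic wait can form.
if rank % 2 == 0:
    dist.send(send_buf, nxt)
else:
    dist.recv(recv_buf, prv)

handle = dist.barrier(async_op=True)    # non-blocking barrier, as used above
handle.wait()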
@@ -100,23 +100,20 @@
x_list = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before reduce_scatter with FLAGCX_GROUP1: x_list = {x_list}, z = {z}")
dist.reduce_scatter(z, x_list, op=dist.ReduceOp.SUM, group=FLAGCX_GROUP1)
# dist.barrier(group=FLAGCX_GROUP1)
print(f"rank {rank} after reduce_scatter with FLAGCX_GROUP1: x_list = {x_list}, z = {z}")
dist.barrier(group=FLAGCX_GROUP1)
for i in range(world_size):
x[i] = rank
print(f"rank {rank} before _reduce_scatter_base with FLAGCX_GROUP1: x = {x}, z = {z}")
dist._reduce_scatter_base(z, x, op=dist.ReduceOp.MAX, group=FLAGCX_GROUP1)
# dist.barrier(group=FLAGCX_GROUP1)
print(f"rank {rank} after _reduce_scatter_base with FLAGCX_GROUP1: x = {x}, z = {z}")

# Perform alltoall with FLAGCX_GROUP2
for i in range(world_size):
y[i] = rank
x[i] = 0
list_y = list(torch.chunk(y, world_size, dim=0))
list_z = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.all_to_all(list_z, list_y, group=FLAGCX_GROUP2)
print(f"rank {rank} after all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.barrier(group=FLAGCX_GROUP2)
# Perform broadcast with FLAGCX_GROUP2
x = torch.rand(world_size).cuda()
print(f"rank {rank} before broadcast with FLAGCX_GROUP2: x = {x}")
dist.broadcast(x, 0, group=FLAGCX_GROUP2)
# dist.barrier(group=FLAGCX_GROUP2)
print(f"rank {rank} after broadcast with FLAGCX_GROUP2: x = {x}")

dist.destroy_process_group()
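For the reduce-scatter calls exercised above, the list form and the flat-tensor form give the same result. A small sketch assuming an initialized process group and one GPU per rank (dist._reduce_scatter_base is the older name for reduce_scatter_tensor in newer PyTorch):

import torch
import torch.distributed as dist

world, rank = dist.get_world_size(), dist.get_rank()
out = torch.empty(1, device="cuda")
flat_in = torch.full((world,), float(rank), device="cuda")   # slot r on every rank holds `rank`

# Flat-tensor form: one contiguous input of world_size * out.numel() elements.
dist._reduce_scatter_base(out, flat_in, op=dist.ReduceOp.SUM)
# Every rank ends with the sum over ranks of its own slot: world * (world - 1) / 2.

# List form: same result from an explicit list of per-rank chunks.
dist.reduce_scatter(out, list(torch.chunk(flat_in, world, dim=0)), op=dist.ReduceOp.SUM)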
1 change: 1 addition & 0 deletions plugin/torch/include/backend_flagcx.hpp
@@ -6,6 +6,7 @@
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <c10/core/DeviceGuard.h>

#include <pybind11/chrono.h>
#include <unordered_map>
28 changes: 28 additions & 0 deletions plugin/torch/run_hetero.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# Check if the debug flag is provided as an argument
if [ "$1" == "debug" ]; then
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=all
echo "NCCL debug information enabled."
else
unset NCCL_DEBUG
unset NCCL_DEBUG_SUBSYS
echo "NCCL debug information disabled."
fi

export FLAGCX_DEBUG=INFO
export FLAGCX_DEBUG_SUBSYS=ALL
export FLAGCX_USENET=mlx5_0
export FLAGCX_USEDEV=1
export GLOO_SOCKET_IFNAME=ibs4
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PATH=/usr/local/corex/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:$PATH
export LD_LIBRARY_PATH=./:/usr/local/corex/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Need to preload customized gloo library specified for FlagCX linkage
# export LD_PRELOAD=/usr/local/lib/libgloo.so
# export LD_PRELOAD=/usr/local/nccl/build/lib/libnccl.so
CMD='torchrun --nproc_per_node 8 --nnodes=2 --node_rank=1 --master_addr="10.31.30.232" --master_port=8122 example.py'

echo $CMD
eval $CMD
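On the Python side, the variables torchrun exports (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) are enough to initialize the process group. A minimal consumer sketch, with "nccl" standing in for the flagcx backend string:

import os
import torch
import torch.distributed as dist

# torchrun exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT per worker.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")   # env:// rendezvous; backend string is illustrative
print(f"rank {dist.get_rank()}/{dist.get_world_size()} on local GPU {local_rank}")
dist.destroy_process_group()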
17 changes: 7 additions & 10 deletions plugin/torch/src/backend_flagcx.cpp
@@ -1,6 +1,5 @@
#include "backend_flagcx.hpp"
#include <iostream>
#include <c10/core/DeviceGuard.h>

namespace c10d
{
@@ -138,14 +137,11 @@ namespace c10d

bool WorkFlagcx::wait(std::chrono::milliseconds /* unused */)
{
// TODO: find a solution to block flagcx stream on default stream,
// otherwise we have to call torch distributed ops under a customized stream context.
// if (!coalesced_)
// {
// event_->block(device_id_);
// }
// if (isBarrierOp_)
// For now, we directly call stream sync.
if (!coalesced_)
{
event_->block(device_id_);
}
if (isBarrierOp_)
{
handler_->streamSynchronize(stream_);
}
@@ -284,6 +280,7 @@ namespace c10d
work->event_->record(stream, device_id);
work->device_id_ = device_id;
work->coalesced_ = false;
work->isBarrierOp_ = true;
// Create a future to track the coalesced operation
work->future_ = c10::make_intrusive<c10::ivalue::Future>(c10::ListType::create(c10::TensorType::get()));
work->future_->markCompleted(c10::IValue(0));
@@ -822,4 +819,4 @@ namespace c10d
m.def("createBackendFlagcx", &BackendFlagcx::createBackendFlagcx);
}

} // namespace c10d
} // namespace c10d
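The wait() change blocks the caller's stream on the event recorded after the collective, and fully drains the stream for barrier ops. The analogous pattern in PyTorch stream/event terms, as an illustration rather than the plugin's actual code path:

import torch

side = torch.cuda.Stream()
ev = torch.cuda.Event()
x = torch.randn(1 << 20, device="cuda")
with torch.cuda.stream(side):
    y = x * 2                                 # stand-in for the collective on the plugin's stream
    ev.record(side)                           # like work->event_->record(stream, device_id)
# Ordinary wait(): make the current stream wait on the event (a GPU-side dependency,
# the analogue of event_->block(device_id_)).
torch.cuda.current_stream().wait_event(ev)
# Barrier wait(): additionally drain the stream on the host, the analogue of
# handler_->streamSynchronize(stream_).
side.synchronize()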
4 changes: 2 additions & 2 deletions test/perf/test_allgather.cpp
@@ -64,15 +64,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
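The perf tests now drain the device stream before starting and stopping the timer, so only completed device work is measured; the same change is applied in the three tests below. The equivalent timing pattern in PyTorch, keeping the test's size / 1e9 / time bandwidth convention:

import time
import torch
import torch.distributed as dist

def bench_allgather(recv, send, iters=100, warmup=10):
    # Assumes an initialized process group and recv.numel() == world_size * send.numel().
    for _ in range(warmup):
        dist._all_gather_base(recv, send)
    torch.cuda.synchronize()                  # drain the stream before the timer starts
    dist.barrier()                            # align ranks, like MPI_Barrier
    t0 = time.perf_counter()
    for _ in range(iters):
        dist._all_gather_base(recv, send)
    torch.cuda.synchronize()                  # count only completed device work
    elapsed = (time.perf_counter() - t0) / iters
    gbytes = recv.numel() * recv.element_size() / 1e9
    return gbytes / elapsed                   # GB/s, matching base_bw = size / 1e9 / elapsed_time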
4 changes: 2 additions & 2 deletions test/perf/test_allreduce.cpp
@@ -69,15 +69,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAllReduce(sendbuff, recvbuff, count, DATATYPE, flagcxSum, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAllReduce(sendbuff, recvbuff, count, DATATYPE, flagcxSum, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
4 changes: 2 additions & 2 deletions test/perf/test_alltoall.cpp
@@ -69,15 +69,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAlltoAll(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAlltoAll(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
4 changes: 2 additions & 2 deletions test/perf/test_sendrecv.cpp
@@ -73,7 +73,7 @@ int main(int argc, char *argv[]){
flagcxRecv(recvbuff, count, DATATYPE, recvPeer, comm, stream);
flagcxGroupEnd();
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

Expand All @@ -84,7 +84,7 @@ int main(int argc, char *argv[]){
flagcxRecv(recvbuff, count, DATATYPE, recvPeer, comm, stream);
flagcxGroupEnd();
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
