Add several minor fixes #32

Merged
merged 1 commit on Feb 10, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ FlagCX leverages native collective communications libraries to provide the full
| alltoall | ✓ | ✓ | ✓ | ✓ | ✓ |
| group ops | ✓ | ✓ | ✓ | ? | ? |

Note that `Homo` and `Hetero` modes refer to communications among homogeneouse and heterogeneous clusters. All supported native collective communication libraries can be referenced through the links below:
Note that `Homo` and `Hetero` modes refer to communications among homogeneous and heterogeneous clusters. All supported native collective communication libraries can be referenced through the links below:

- [NCCL](https://github.com/NVIDIA/nccl), NVIDIA Collective Communications Library.
- [IXCCL](https://www.iluvatar.com/software?fullCode=cpjs-rj-rjz), Iluvatar Corex Collective Communications Library.
9 changes: 8 additions & 1 deletion flagcx/flagcx.cc
@@ -640,8 +640,9 @@ flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff, size_t coun
// intra-cluster reduce
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->reduce(sendbuff, recvbuff, count, datatype, op, comm->homo_inter_rank, comm->homo_comm, stream));

// inter-cluster sendrecv
deviceAdaptor->streamSynchronize(stream);

// inter-cluster sendrecv
if (comm->homo_inter_rank != comm->homo_rank)
{
if (op == flagcxSum) {
@@ -670,6 +671,8 @@ flagcxResult_t flagcxAllReduce(const void *sendbuff, void *recvbuff, size_t coun
}
flagcxGroupEnd();

deviceAdaptor->streamSynchronize(stream);

// intra-cluster allreduce
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->allReduce(recvbuff, recvbuff, count, datatype, op, comm->homo_comm, stream));
}
@@ -825,6 +828,8 @@ flagcxResult_t flagcxAllGather(const void *sendbuff, void *recvbuff, size_t send
flagcxGroupEnd();
}

deviceAdaptor->streamSynchronize(stream);

// intra-cluster broadcast
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->broadcast(recvbuff, recvbuff, sendcount * comm->nranks, datatype, comm->homo_inter_rank, comm->homo_comm, stream));
}
@@ -962,6 +967,8 @@ flagcxResult_t flagcxAlltoAll(const void *sendbuff, void *recvbuff, size_t count
}
}
flagcxGroupEnd();

deviceAdaptor->streamSynchronize(stream);
}
}
return flagcxSuccess;
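Taken together, the hunks above enforce an explicit ordering in the hierarchical collectives: the device-side intra-cluster phase is drained before the inter-cluster send/recv phase, and that phase is drained before the final intra-cluster collective is enqueued. A minimal sketch of the resulting flagcxAllReduce flow is shown below; it reuses the identifiers from the diff, elides the inter-cluster send/recv loop and error handling, and assumes flagcxGroupStart pairs with the flagcxGroupEnd calls shown above, so it is an outline rather than a drop-in implementation.

    // 1) Intra-cluster reduce onto the designated inter-cluster rank.
    FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->reduce(
        sendbuff, recvbuff, count, datatype, op,
        comm->homo_inter_rank, comm->homo_comm, stream));

    // 2) Host-side wait so the reduced data is complete before any
    //    inter-cluster traffic reads recvbuff.
    deviceAdaptor->streamSynchronize(stream);

    // 3) Inter-cluster exchange between the homo_inter_rank processes
    //    (send/recv body elided; see the hunk above).
    flagcxGroupStart();
    // ...
    flagcxGroupEnd();

    // 4) Host-side wait (added by this PR) so the exchanged data is in place
    //    before the next collective is enqueued.
    deviceAdaptor->streamSynchronize(stream);

    // 5) Intra-cluster allreduce to propagate the combined result.
    FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->allReduce(
        recvbuff, recvbuff, count, datatype, op, comm->homo_comm, stream));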
53 changes: 25 additions & 28 deletions plugin/torch/example.py
@@ -27,14 +27,14 @@

# Create two groups
ranks = list(range(world_size))
FLAGCX_GROUP1 = dist.new_group(ranks=ranks, backend=f"cpu:gloo,{dev_name}:flagcx")
FLAGCX_GROUP2 = dist.new_group(ranks=ranks, backend=f"cpu:gloo,{dev_name}:flagcx")
FLAGCX_GROUP1 = dist.new_group(ranks=ranks, backend=f"{dev_name}:flagcx")
FLAGCX_GROUP2 = dist.new_group(ranks=ranks, backend=f"{dev_name}:flagcx")
ranks_flagcx = dist.get_process_group_ranks(FLAGCX_GROUP1)
print(f"ranks_flagcx: {ranks_flagcx}")

if torch.cuda.is_available():
# Create tensors
torch.cuda.set_device(rank)
torch.cuda.set_device(rank % 8)
x = torch.rand(world_size).cuda()
y = torch.rand(world_size).cuda()
print(f"rank {rank} initial: x = {x}, y = {y}")
@@ -46,7 +46,6 @@
print(f"rank {rank} after allreduce max with FLAGCX_GROUP1: y = {y}")
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=FLAGCX_GROUP1)
print(f"rank {rank} after allreduce sum with FLAGCX_GROUP1: x = {x}")
dist.barrier(group=FLAGCX_GROUP1)

# Perform send and recv with FLAGCX_GROUP2
for i in range(world_size):
@@ -55,43 +54,44 @@
op_send = dist.P2POp(dist.isend, x, next_rank, group=FLAGCX_GROUP2)
op_recv = dist.P2POp(dist.irecv, y, prev_rank, group=FLAGCX_GROUP2)
op_list = [op_send, op_recv]
if rank % 2 == 1:
op_list.reverse()
reqs = dist.batch_isend_irecv(op_list)
for req in reqs:
req.wait()
print(f"rank {rank} after batch_isend_irecv with FLAGCX_GROUP2: x = {x}, y = {y}")
if rank % 2 == 0:
dist.send(x, next_rank, group=FLAGCX_GROUP2)
dist.send(x, next_rank, group=FLAGCX_GROUP2)
elif rank % 2 == 1:
dist.recv(y, prev_rank, group=FLAGCX_GROUP2)
elif rank % 2 == 1:
dist.recv(y, prev_rank, group=FLAGCX_GROUP2)
print(f"rank {rank} after send/recv with FLAGCX_GROUP2: x = {x}, y = {y}")
dist.send(x, next_rank, group=FLAGCX_GROUP2)
handle = dist.barrier(group=FLAGCX_GROUP2, async_op=True)
handle.wait()
print(f"rank {rank} after send/recv with FLAGCX_GROUP2: x = {x}, y = {y}")

# Perform allgather with FLAGCX_GROUP1
z = torch.rand(1).cuda()
z[0] = rank
print(f"rank {rank} before _all_gather_base with FLAGCX_GROUP1: z = {z}, y = {y}")
dist._all_gather_base(y, z, group=FLAGCX_GROUP1)
print(f"rank {rank} after _all_gather_base with FLAGCX_GROUP1: z = {z}, y = {y}")
z_list = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_gather with FLAGCX_GROUP1: z = {z}, z_list = {z_list}")
dist.all_gather(z_list, z, group=FLAGCX_GROUP1)
print(f"rank {rank} after all_gather with FLAGCX_GROUP1: z = {z}, z_list = {z_list}")
print(z_list)
all_rank_infos = [None] * world_size
cur_rank_info = {'rank': rank, 'device_type': f"cpu:gloo,{dev_name}:flagcx"}
dist.all_gather_object(all_rank_infos, cur_rank_info)
print(f"rank {rank} after all_gather_object with FLAGCX_GROUP1: all_rank_infos = {all_rank_infos}")
dist.barrier(group=FLAGCX_GROUP1)

# Perform broadcast with FLAGCX_GROUP2
x = torch.rand(world_size).cuda()
print(f"rank {rank} before broadcast with FLAGCX_GROUP2: x = {x}")
dist.broadcast(x, 0, group=FLAGCX_GROUP2)
print(f"rank {rank} after broadcast with FLAGCX_GROUP2: x = {x}")
dist.barrier(group=FLAGCX_GROUP2)
# Perform alltoall with FLAGCX_GROUP2
for i in range(world_size):
y[i] = rank
x[i] = 0
list_y = list(torch.chunk(y, world_size, dim=0))
list_z = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.all_to_all(list_z, list_y, group=FLAGCX_GROUP2)
print(f"rank {rank} after all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")

# Perform reducescatter with FLAGCX_GROUP1
z[0] = 0
@@ -100,23 +100,20 @@
x_list = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before reduce_scatter with FLAGCX_GROUP1: x_list = {x_list}, z = {z}")
dist.reduce_scatter(z, x_list, op=dist.ReduceOp.SUM, group=FLAGCX_GROUP1)
# dist.barrier(group=FLAGCX_GROUP1)
print(f"rank {rank} after reduce_scatter with FLAGCX_GROUP1: x_list = {x_list}, z = {z}")
dist.barrier(group=FLAGCX_GROUP1)
for i in range(world_size):
x[i] = rank
print(f"rank {rank} before _reduce_scatter_base with FLAGCX_GROUP1: x = {x}, z = {z}")
dist._reduce_scatter_base(z, x, op=dist.ReduceOp.MAX, group=FLAGCX_GROUP1)
# dist.barrier(group=FLAGCX_GROUP1)
print(f"rank {rank} after _reduce_scatter_base with FLAGCX_GROUP1: x = {x}, z = {z}")

# Perform alltoall with FLAGCX_GROUP2
for i in range(world_size):
y[i] = rank
x[i] = 0
list_y = list(torch.chunk(y, world_size, dim=0))
list_z = list(torch.chunk(x, world_size, dim=0))
print(f"rank {rank} before all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.all_to_all(list_z, list_y, group=FLAGCX_GROUP2)
print(f"rank {rank} after all_to_all with FLAGCX_GROUP2: list_y = {list_y}, list_z = {list_z}")
dist.barrier(group=FLAGCX_GROUP2)
# Perform broadcast with FLAGCX_GROUP2
x = torch.rand(world_size).cuda()
print(f"rank {rank} before broadcast with FLAGCX_GROUP2: x = {x}")
dist.broadcast(x, 0, group=FLAGCX_GROUP2)
# dist.barrier(group=FLAGCX_GROUP2)
print(f"rank {rank} after broadcast with FLAGCX_GROUP2: x = {x}")

dist.destroy_process_group()
1 change: 1 addition & 0 deletions plugin/torch/include/backend_flagcx.hpp
@@ -6,6 +6,7 @@
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <c10/core/DeviceGuard.h>

#include <pybind11/chrono.h>
#include <unordered_map>
Expand Down
28 changes: 28 additions & 0 deletions plugin/torch/run_hetero.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# Check if the debug flag is provided as an argument
if [ "$1" == "debug" ]; then
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=all
echo "NCCL debug information enabled."
else
unset NCCL_DEBUG
unset NCCL_DEBUG_SUBSYS
echo "NCCL debug information disabled."
fi

export FLAGCX_DEBUG=INFO
export FLAGCX_DEBUG_SUBSYS=ALL
export FLAGCX_USENET=mlx5_0
export FLAGCX_USEDEV=1
export GLOO_SOCKET_IFNAME=ibs4
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PATH=/usr/local/corex/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:$PATH
export LD_LIBRARY_PATH=./:/usr/local/corex/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Need to preload customized gloo library specified for FlagCX linkage
# export LD_PRELOAD=/usr/local/lib/libgloo.so
# export LD_PRELOAD=/usr/local/nccl/build/lib/libnccl.so
CMD='torchrun --nproc_per_node 8 --nnodes=2 --node_rank=1 --master_addr="10.31.30.232" --master_port=8122 example.py'

echo $CMD
eval $CMD
17 changes: 7 additions & 10 deletions plugin/torch/src/backend_flagcx.cpp
@@ -1,6 +1,5 @@
#include "backend_flagcx.hpp"
#include <iostream>
#include <c10/core/DeviceGuard.h>

namespace c10d
{
@@ -138,14 +137,11 @@ namespace c10d

bool WorkFlagcx::wait(std::chrono::milliseconds /* unused */)
{
// TODO: find a solution to block flagcx stream on default stream,
// otherwise we have to call torch distributed ops under a customized stream context.
// if (!coalesced_)
// {
// event_->block(device_id_);
// }
// if (isBarrierOp_)
// For now, we directly call stream sync.
if (!coalesced_)
{
event_->block(device_id_);
}
if (isBarrierOp_)
Comment on lines +140 to +144

Collaborator:
What is the reason for this code working now?

Collaborator Author:
It seems that a stream synchronization call is required after flagcxHeteroSend/Recv ops; such synchronization is added in all C2C algorithms for the moment. Here we adopt the same approach as ProcessGroupNCCL, blocking the current stream on the flagcx stream.

{
handler_->streamSynchronize(stream_);
}
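The review exchange above refers to this restored logic. As a reading aid, the same wait() path is sketched below with explanatory comments; the identifiers come from the diff, the behaviour is unchanged, and the remainder of the function (not shown in the hunk) is elided.

    bool WorkFlagcx::wait(std::chrono::milliseconds /* unused */)
    {
        if (!coalesced_)
        {
            // Make the caller's current stream wait on the flagcx communication
            // stream via the recorded event (the ProcessGroupNCCL approach), so
            // later kernels observe the communication results without a host sync.
            event_->block(device_id_);
        }
        if (isBarrierOp_)
        {
            // Barrier semantics still require a host-side wait on the flagcx stream.
            handler_->streamSynchronize(stream_);
        }
        // ... remainder of the function as in the original source ...
    }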
@@ -284,6 +280,7 @@ namespace c10d
work->event_->record(stream, device_id);
work->device_id_ = device_id;
work->coalesced_ = false;
work->isBarrierOp_ = true;
// Create a future to track the coalesced operation
work->future_ = c10::make_intrusive<c10::ivalue::Future>(c10::ListType::create(c10::TensorType::get()));
work->future_->markCompleted(c10::IValue(0));
@@ -822,4 +819,4 @@ namespace c10d
m.def("createBackendFlagcx", &BackendFlagcx::createBackendFlagcx);
}

} // namespace c10d
} // namespace c10d
4 changes: 2 additions & 2 deletions test/perf/test_allgather.cpp
@@ -64,15 +64,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
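This test and the ones below make the same substitution: the warmup and timed loops now end with devHandle->streamSynchronize(stream) instead of flagcxBarrier(comm, stream), so each process waits for its own enqueued collectives to complete while MPI_Barrier provides cross-process alignment before timing starts. A condensed sketch of the timing pattern, reusing identifiers from test_allgather.cpp (buffer and communicator setup elided), looks like this:

    // Warm up: enqueue a few collectives, then wait for the local stream to finish.
    for (int i = 0; i < num_warmup_iters; i++) {
        flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
    }
    devHandle->streamSynchronize(stream);

    // Align all processes before starting the timer.
    MPI_Barrier(MPI_COMM_WORLD);

    // Timed region: the trailing synchronize ensures the measured interval
    // covers completed device work rather than just enqueueing.
    tim.reset();
    for (int i = 0; i < num_iters; i++) {
        flagcxAllGather(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
    }
    devHandle->streamSynchronize(stream);

    double elapsed_time = tim.elapsed() / num_iters;         // seconds per iteration
    double base_bw = (double)(size) / 1.0E9 / elapsed_time;  // GB/s from message size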
4 changes: 2 additions & 2 deletions test/perf/test_allreduce.cpp
@@ -69,15 +69,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAllReduce(sendbuff, recvbuff, count, DATATYPE, flagcxSum, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAllReduce(sendbuff, recvbuff, count, DATATYPE, flagcxSum, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
4 changes: 2 additions & 2 deletions test/perf/test_alltoall.cpp
@@ -69,15 +69,15 @@ int main(int argc, char *argv[]){
for(int i=0;i<num_warmup_iters;i++){
flagcxAlltoAll(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

tim.reset();
for(int i=0;i<num_iters;i++){
flagcxAlltoAll(sendbuff, recvbuff, count / totalProcs, DATATYPE, comm, stream);
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;
4 changes: 2 additions & 2 deletions test/perf/test_sendrecv.cpp
@@ -73,7 +73,7 @@ int main(int argc, char *argv[]){
flagcxRecv(recvbuff, count, DATATYPE, recvPeer, comm, stream);
flagcxGroupEnd();
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

MPI_Barrier(MPI_COMM_WORLD);

@@ -84,7 +84,7 @@ int main(int argc, char *argv[]){
flagcxRecv(recvbuff, count, DATATYPE, recvPeer, comm, stream);
flagcxGroupEnd();
}
flagcxBarrier(comm, stream);
devHandle->streamSynchronize(stream);

double elapsed_time = tim.elapsed() / num_iters;
double base_bw = (double)(size) / 1.0E9 / elapsed_time;