Dev refactor xccl primitive #10613

Merged: 32 commits merged into master from dev_refactor_xccl_primitive on Jan 26, 2025

Conversation

@Flowingsun007 (Contributor) commented on Jan 3, 2025:

  • Push forward decoupling OneFlow from its deep binding to CUDA NCCL by refactoring EagerCclCommMgr, ccl::Comm, and related modules, so that kernels can directly use device-agnostic (primitive-like) ccl communication calls instead of raw NCCL APIs, paving the way for multi-device compatibility.
  • When supporting/adapting to other devices (cuda/npu/xpu, etc.) later, kernels and other code that issues communication calls should, in principle, not call device-coupled communication APIs such as NCCL directly. Instead they should use the parent-class APIs oneflow::ccl::Send/Recv/AllReduce/... (located under oneflow/user/kernels/collective_communication/include) and provide subclass implementations.
  • Each device then derives from the oneflow::ccl communication APIs to implement its own subclass APIs (a minimal sketch of this pattern follows this list):
    • e.g. CUDA devices implement oneflow::ccl::CudaSend/CudaRecv/CudaAllReduce/... via the NCCL API.
    • NPU devices implement oneflow::ccl::NpuSend/NpuRecv/NpuAllReduce/... via the HCCL API.
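
To make the intended layering concrete, here is a minimal C++ sketch of the pattern described in this list. The Launch signature and the stand-in Stream type are illustrative assumptions, not the actual declarations under oneflow/user/kernels/collective_communication/include:

    #include <cstddef>
    #include <cstdint>

    namespace oneflow {
    namespace ccl {

    class Stream;  // stand-in for the execution-provider stream type

    // Device-agnostic interface that kernels would call.
    class Send {
     public:
      virtual ~Send() = default;
      virtual void Launch(Stream* stream, const void* in, size_t elem_cnt,
                          int64_t dst) const = 0;
    };

    // CUDA subclass: forwards to ncclSend under the hood.
    class CudaSend : public Send {
     public:
      void Launch(Stream* stream, const void* in, size_t elem_cnt,
                  int64_t dst) const override {
        // ... ncclSend(in, elem_cnt, dtype, dst, nccl_comm, cuda_stream) ...
      }
    };

    // NPU subclass: forwards to the HCCL send API instead.
    class NpuSend : public Send {
     public:
      void Launch(Stream* stream, const void* in, size_t elem_cnt,
                  int64_t dst) const override {
        // ... HCCL send call ...
      }
    };

    }  // namespace ccl
    }  // namespace oneflow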

@ShawnXuan (Collaborator) commented:

Should these call sites be replaced with AllToAll?

1. oneflow/core/job/collective_boxing/nccl_executor_backend.cu ❌

Link: https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job/collective_boxing/nccl_executor_backend.cu#L367-L383

Code:

            for (int64_t j = 0; j < num_ranks; ++j) {
              OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                         reinterpret_cast<const char*>(send_buff) + j * chunk_size),
                                     elem_per_chunk, nccl_data_type, j, comm,
                                     stream_ctx->stream()));
              OF_NCCL_CHECK(ncclRecv(
                  reinterpret_cast<void*>(reinterpret_cast<char*>(recv_buff) + j * chunk_size),
                  elem_per_chunk, nccl_data_type, j, comm, stream_ctx->stream()));
            }

No replacement needed on the CUDA side; the requirement is only that each device also implements its own xccl_executor_backend.cu.

  • In oneflow-npu's hccl_executor_backend.cu, the op_type == OpType::kOpTypeAll2All branch can be replaced with HcclAllToAll (see the sketch below).
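
The grouped ncclSend/ncclRecv loop quoted above is an even-chunk all-to-all, so a device backend can expose it behind a single collective entry point (the NPU backend would call HcclAllToAll instead, as noted above). A minimal NCCL-based sketch, assuming a hypothetical wrapper name NcclAllToAll and omitting per-call error checking:

    #include <cstddef>
    #include <nccl.h>
    #include <cuda_runtime.h>

    // Sketch only: error handling for the individual NCCL calls is omitted.
    ncclResult_t NcclAllToAll(const void* send_buff, void* recv_buff,
                              size_t elem_per_chunk, size_t dtype_size,
                              ncclDataType_t dtype, int num_ranks,
                              ncclComm_t comm, cudaStream_t stream) {
      const size_t chunk_size = elem_per_chunk * dtype_size;
      ncclGroupStart();
      for (int j = 0; j < num_ranks; ++j) {
        // Chunk j of send_buff goes to rank j; chunk j of recv_buff comes
        // from rank j, matching the grouped ncclSend/ncclRecv loop above.
        ncclSend(static_cast<const char*>(send_buff) + j * chunk_size,
                 elem_per_chunk, dtype, j, comm, stream);
        ncclRecv(static_cast<char*>(recv_buff) + j * chunk_size,
                 elem_per_chunk, dtype, j, comm, stream);
      }
      return ncclGroupEnd();
    }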

2. oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp ❓

Link: https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp#L122-L133

Code:

  OF_NCCL_CHECK(ncclGroupStart());
  for (int64_t i = 0; i < parallel_num; ++i) {
    if (this->has_input() && send_elem_cnts.at(i) != 0) {
      OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i,
                             comm, cuda_stream));
    }
    if (this->has_output() && recv_elem_cnts.at(i) != 0) {
      OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type),
                             i, comm, cuda_stream));
    }
  }
  OF_NCCL_CHECK(ncclGroupEnd());

This needs to be replaced with AllToAll, but a few details need to be worked out (see the sketch after this list):

  • has_input(), has_output()
  • send_elem_cnts, recv_elem_cnts
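
Because the counts can differ per rank and either direction may be absent, the replacement here is really a variable-count all-to-all. A hedged sketch of what such a primitive could look like on top of NCCL; NcclAllToAllV is a hypothetical name, not an existing API, and when has_input()/has_output() is false the corresponding count vector is all zeros, so that direction is skipped:

    #include <cstdint>
    #include <vector>
    #include <nccl.h>
    #include <cuda_runtime.h>

    // Sketch only: per-call error handling is omitted.
    ncclResult_t NcclAllToAllV(const std::vector<const void*>& send_ptrs,
                               const std::vector<int64_t>& send_elem_cnts,
                               const std::vector<void*>& recv_ptrs,
                               const std::vector<int64_t>& recv_elem_cnts,
                               ncclDataType_t dtype, ncclComm_t comm,
                               cudaStream_t stream) {
      const int64_t parallel_num = static_cast<int64_t>(send_elem_cnts.size());
      ncclGroupStart();
      for (int64_t i = 0; i < parallel_num; ++i) {
        // Ranks with zero counts are simply skipped, mirroring the
        // has_input()/has_output() and elem_cnt checks in the kernel above.
        if (send_elem_cnts[i] != 0) {
          ncclSend(send_ptrs[i], send_elem_cnts[i], dtype, i, comm, stream);
        }
        if (recv_elem_cnts[i] != 0) {
          ncclRecv(recv_ptrs[i], recv_elem_cnts[i], dtype, i, comm, stream);
        }
      }
      return ncclGroupEnd();
    }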

3. oneflow/user/kernels/eager_nccl_s2s_kernel.cu ✅

Link: oneflow/user/kernels/eager_nccl_s2s_kernel.cu

Code:

      OF_NCCL_CHECK(ncclGroupStart());
      const int64_t elem_per_chunk = elem_cnt / num_ranks;
      const int64_t chunk_size = elem_per_chunk * dtype_size;
      for (int64_t j = 0; j < num_ranks; ++j) {
        OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                   reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
                               elem_per_chunk, GetNcclDataType(in->data_type()), j,
                               kernel_cache->comm(),
                               ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
        OF_NCCL_CHECK(ncclRecv(
            reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
            elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_cache->comm(),
            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
      }
      OF_NCCL_CHECK(ncclGroupEnd());

Can be replaced; the chunks are evenly distributed.

4. oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp ✅

Link: https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp#L345-L359

Code:

      OF_NCCL_CHECK(ncclGroupStart());
      const int64_t elem_per_chunk = elem_cnt / num_ranks;
      const int64_t chunk_size = elem_per_chunk * dtype_size;
      for (int64_t j = 0; j < num_ranks; ++j) {
        OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                   reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
                               elem_per_chunk, GetNcclDataType(in->data_type()), j,
                               kernel_state->comm(),
                               ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
        OF_NCCL_CHECK(ncclRecv(
            reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
            elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(),
            ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
      }
      OF_NCCL_CHECK(ncclGroupEnd());

Can be replaced; the chunks are evenly distributed.

5. oneflow/user/kernels/nccl_logical_fusion_kernel.cpp ✅

There are two occurrences in this file:

Both can be replaced.

6. oneflow/user/kernels/nccl_logical_kernels.cpp ✅

Can be replaced.

7. oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp ❓

Code:

  for (int64_t i = 0; i < parallel_num; ++i) {
    if (send_elem_cnts.at(i) != 0) {
      LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i;
      OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i,
                             comm, cuda_stream));
    }
    if (recv_elem_cnts.at(i) != 0) {
      LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i;
      OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type),
                             i, comm, cuda_stream));
    }
  }

The send_elem_cnts and recv_elem_cnts need some thought (the same detail as in item 2).

8. oneflow/user/kernels/one_embedding_data_shuffle.cuh ❌

  for (int64_t i = 0; i < parallel_num; ++i) {
    OF_NCCL_CHECK(ncclSend(send_data + send_offsets.at(i), send_elem_cnt.at(i), nccl_data_type, i,
                           comm, cuda_stream));
    OF_NCCL_CHECK(ncclRecv(recv_data + recv_offsets.at(i), recv_elem_cnt.at(i), nccl_data_type, i,
                           comm, cuda_stream));
  }

Don't replace this for now; one_embedding is not supported yet.

Comment on lines 33 to 42
  virtual ccl::CclComm GetCclCommForDevice(
      const std::set<std::pair<int64_t, int64_t>>& device_set) {
    ccl::CclComm ccl_comm{};
    return ccl_comm;
  }
  virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
      const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {
    ccl::CclComm ccl_comm{};
    return ccl_comm;
  }
Contributor:

These two methods are not a good fit for the abstract class; they don't apply to the CPU-related implementations.

@Flowingsun007 (Contributor, Author), Jan 13, 2025:

Hmm, what about making these pure virtual? The CPU side currently has no CommMgr subclass implementation anyway (and if the CPU needs one later, calling UNIMPLEMENTED() directly in the subclass methods also seems fine?)

Contributor:

The device_set also comes from the ParallelDesc. The shape of this interface could be changed: take a ParallelDesc as input and return a CommunicationContext.

Comment on lines 44 to 63
  // abstract base class for comm
  class CommBase {
   public:
    virtual ~CommBase() = default;

    // return impl of comm
    virtual void* getComm() = 0;
  };

  class CclComm {
   public:
    CclComm() {}
    explicit CclComm(std::shared_ptr<CommBase> comm) : comm_(std::move(comm)) {}

    void* getComm() { return comm_->getComm(); }

   private:
    std::shared_ptr<CommBase> comm_{};
  };

Contributor:

Doesn't this duplicate the role of CommunicationContext? There's no need to introduce a second mechanism.

@Flowingsun007 (Contributor, Author):

I feel there's still a difference: as I understand it, CommunicationContext initializes the corresponding comm from a ParallelDesc, whereas CclComm takes an already-created comm object, wraps it, and exposes a unified getComm() method.

Contributor:

> I feel there's still a difference: as I understand it, CommunicationContext initializes the corresponding comm from a ParallelDesc, whereas CclComm takes an already-created comm object, wraps it, and exposes a unified getComm() method.

That already-created comm object is itself built from a ParallelDesc too, so essentially it's the same thing.
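
For reference, a minimal sketch of how the CommBase/CclComm pair quoted above could be used on the CUDA side, matching the "wrap an already-created comm" reading being debated here; the adapter name NcclCommAdapter is a hypothetical illustration, not code from this PR:

    #include <memory>
    #include <nccl.h>

    // Hypothetical CUDA-side adapter for the CommBase/CclComm pair quoted above.
    class NcclCommAdapter : public CommBase {
     public:
      explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {}
      // getComm() hands back the raw communicator; the CUDA primitive
      // casts it back to ncclComm_t before calling NCCL.
      void* getComm() override { return comm_; }

     private:
      ncclComm_t comm_;
    };

    // Usage: wrap an already-created ncclComm_t once, then pass the
    // device-agnostic CclComm through kernels.
    //   CclComm ccl_comm(std::make_shared<NcclCommAdapter>(nccl_comm));
    //   ncclComm_t raw = static_cast<ncclComm_t>(ccl_comm.getComm());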

Comment on lines 33 to 36
  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;

  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
                      CclComm ccl_comm) const = 0;
Contributor:

Likewise, there's no need to define two versions of the Launch interface.

@Flowingsun007 (Contributor, Author):

Mainly because the existing one is hard to change 😂, so I added a new one. The old comm is fairly implicit; for example, in the CUDA recv implementation the comm is obtained via const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);. The new overload takes a CclComm directly, which feels more intuitive.
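
For reference, a caller-side sketch of the difference between the two overloads under discussion. RecvPrimitive and Stream are stand-in template parameters so the sketch is self-contained; the internal lookup is only mentioned in a comment, as in the discussion above:

    #include <cstddef>
    #include <cstdint>

    template<typename RecvPrimitive, typename Stream>
    void RecvWithImplicitComm(const RecvPrimitive& recv, Stream* stream,
                              void* out, size_t elem_cnt, int64_t src) {
      // Old overload: the CUDA implementation resolves its communicator
      // internally (e.g. via GetNcclCommAndPeerNcclRank(src)).
      recv.Launch(stream, out, elem_cnt, src);
    }

    template<typename RecvPrimitive, typename Stream, typename CclComm>
    void RecvWithExplicitComm(const RecvPrimitive& recv, Stream* stream,
                              void* out, size_t elem_cnt, int64_t src,
                              CclComm ccl_comm) {
      // New overload: the caller passes a device-agnostic CclComm directly.
      recv.Launch(stream, out, elem_cnt, src, ccl_comm);
    }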

@Flowingsun007 enabled auto-merge (squash) January 14, 2025 01:35
@Flowingsun007 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot January 14, 2025 01:35
Resolved review thread (outdated): oneflow/user/kernels/eager_nccl_s2s_kernel.cu
@@ -39,9 +42,9 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
   }
   ~NcclLogical2DSameDim0KernelCommState() override = default;

-  ncclComm_t comm() {
+  ccl::CclComm ccl_comm() {
Contributor:

Suggested change:

-  ccl::CclComm ccl_comm() {
+  const ccl::CclComm& ccl_comm() const {

Resolved review thread (outdated): oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Resolved review thread (outdated): oneflow/user/kernels/nccl_logical_kernels.cpp
Comment on lines 33 to 36
  virtual ccl::CclComm GetCclCommForDevice(
      const std::set<std::pair<int64_t, int64_t>>& device_set) = 0;
  virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
      const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) = 0;
@clackhan (Contributor), Jan 24, 2025:

Suggested change:

-  virtual ccl::CclComm GetCclCommForDevice(
-      const std::set<std::pair<int64_t, int64_t>>& device_set) = 0;
-  virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
-      const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) = 0;
+  virtual ccl::CclComm GetCclCommForParallelDesc(
+      const ParallelDesc& parallel_desc) = 0;
+  virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(
+      const ParallelDesc& parallel_desc, const std::string& stream_name) = 0;

The parameter here should be the parallel_desc. The device_set is the form CUDA needs; building the device_set from the parallel_desc should be done in the derived class.

@Flowingsun007 (Contributor, Author):

If we uniformly pass a ParallelDesc, the 1D/2D cases aren't easy to handle, are they? 😂 For example, in oneflow/user/kernels/nccl_logical_fusion_kernel.cpp the construction of device_set depends on hierarchy.NumAxes() and also needs comm_key_.

@Flowingsun007 (Contributor, Author):

Done, updated.

Contributor:

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.9ms (= 4388.5ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.6ms (= 5759.7ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.31 (= 57.6ms / 43.9ms)

OneFlow resnet50 time: 26.5ms (= 2653.2ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.4ms (= 3740.6ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.41 (= 37.4ms / 26.5ms)

OneFlow resnet50 time: 17.6ms (= 3528.8ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.1ms (= 7013.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.99 (= 35.1ms / 17.6ms)

OneFlow resnet50 time: 18.0ms (= 3606.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.7ms (= 6336.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.76 (= 31.7ms / 18.0ms)

OneFlow resnet50 time: 17.7ms (= 3539.6ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 28.2ms (= 5637.4ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.59 (= 28.2ms / 17.7ms)

OneFlow swin dataloader time: 0.200s (= 40.046s / 200, num_workers=1)
PyTorch swin dataloader time: 0.128s (= 25.601s / 200, num_workers=1)
Relative speed: 0.639 (= 0.128s / 0.200s)

OneFlow swin dataloader time: 0.055s (= 11.057s / 200, num_workers=4)
PyTorch swin dataloader time: 0.032s (= 6.497s / 200, num_workers=4)
Relative speed: 0.588 (= 0.032s / 0.055s)

OneFlow swin dataloader time: 0.031s (= 6.121s / 200, num_workers=8)
PyTorch swin dataloader time: 0.016s (= 3.271s / 200, num_workers=8)
Relative speed: 0.534 (= 0.016s / 0.031s)

❌ OneFlow resnet50 time: 50.0ms (= 5001.1ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.3ms (= 6433.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 64.3ms / 50.0ms)

OneFlow resnet50 time: 36.7ms (= 3675.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 47.3ms (= 4733.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 47.3ms / 36.7ms)

OneFlow resnet50 time: 27.9ms (= 5573.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.0ms (= 7801.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.40 (= 39.0ms / 27.9ms)

OneFlow resnet50 time: 25.4ms (= 5084.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.6ms (= 8311.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.63 (= 41.6ms / 25.4ms)

OneFlow resnet50 time: 25.0ms (= 4992.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 35.9ms (= 7185.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.44 (= 35.9ms / 25.0ms)

@Flowingsun007 merged commit 26a393c into master Jan 26, 2025
21 checks passed
@Flowingsun007 deleted the dev_refactor_xccl_primitive branch January 26, 2025 15:35