From 998dc17aa6ac62a18dbe894337ee288c2434fa86 Mon Sep 17 00:00:00 2001 From: Luyang Date: Thu, 26 Dec 2024 15:11:39 +0800 Subject: [PATCH] fix device_type2sub_tsk_gph_builder_ (#10563) --- oneflow/core/graph/task_graph.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 10280a8dfe5..082b45ba845 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -880,14 +880,16 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { if (device_type != DeviceType::kCPU && device_type2sub_tsk_gph_builder_.find(device_type) != device_type2sub_tsk_gph_builder_.end()) { - status = CHECK_JUST( // NOLINT + auto maybe_status = // NOLINT device_type2sub_tsk_gph_builder_ // NOLINT .at(device_type) // NOLINT ->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT - *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT - } else { + *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT + if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); } + } + if (!status) { status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, @@ -1052,6 +1054,12 @@ Maybe GlobalTaskGraph::Init() { OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); + // Register the corresponding task graph builder based on the device type and store them to map + const auto* global_device_type_create_sub_tsk_gph_builder_fn = + GlobalDeviceType2CreateSubTskGphBuilderFn(); + for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { + device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); + } hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); HashMap> op_node2sorted_comp_tasks; @@ -1088,6 +1096,13 @@ Maybe BoxingTaskGraph::Init( OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); + // Register the corresponding task graph builder based on the device type and store them to map + const auto* global_device_type_create_sub_tsk_gph_builder_fn = + GlobalDeviceType2CreateSubTskGphBuilderFn(); + for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { + device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); + } + hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {