Skip to content

Commit

Permalink
fix device_type2sub_tsk_gph_builder_ (#10563)
Browse files Browse the repository at this point in the history
  • Loading branch information
Flowingsun007 authored Dec 26, 2024
1 parent f7fa76f commit 998dc17
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,14 +880,16 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
if (device_type != DeviceType::kCPU
&& device_type2sub_tsk_gph_builder_.find(device_type)
!= device_type2sub_tsk_gph_builder_.end()) {
status = CHECK_JUST( // NOLINT
auto maybe_status = // NOLINT
device_type2sub_tsk_gph_builder_ // NOLINT
.at(device_type) // NOLINT
->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT
&sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT
blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT
} else {
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT
if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); }
}
if (!status) {
status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp,
Expand Down Expand Up @@ -1052,6 +1054,12 @@ Maybe<void> GlobalTaskGraph::Init() {
OpGraph* op_graph = Singleton<OpGraph>::Get();
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
// Register the corresponding task graph builder based on the device type and store them to map
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
GlobalDeviceType2CreateSubTskGphBuilderFn();
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
}
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;

Expand Down Expand Up @@ -1088,6 +1096,13 @@ Maybe<void> BoxingTaskGraph::Init(
OpGraph* op_graph = Singleton<OpGraph>::Get();
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
// Register the corresponding task graph builder based on the device type and store them to map
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
GlobalDeviceType2CreateSubTskGphBuilderFn();
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
}

hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());

const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {
Expand Down

0 comments on commit 998dc17

Please sign in to comment.