From f850059a64f011c2d12220ceb7088b5f7865f135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Tue, 11 Jun 2024 15:52:37 +0200 Subject: [PATCH] Graph lowering. (#5496) This change implements the "lowering" of the new graph::OpGraph to the old OpGraph used by the current executor. ---- Signed-off-by: Michal Zientkiewicz --- dali/pipeline/executor/executor.h | 4 +- dali/pipeline/executor/executor_impl.h | 10 ++++ dali/pipeline/executor/lowered_graph.cc | 13 ++++ dali/pipeline/executor/lowered_graph.h | 10 ++++ dali/pipeline/executor/op_graph_test.cc | 79 +++++++++++++++++++++++++ dali/pipeline/graph/op_graph2.h | 45 +++++++++----- 6 files changed, 145 insertions(+), 16 deletions(-) diff --git a/dali/pipeline/executor/executor.h b/dali/pipeline/executor/executor.h index 4d68db635df..595222644c1 100644 --- a/dali/pipeline/executor/executor.h +++ b/dali/pipeline/executor/executor.h @@ -40,7 +40,9 @@ using ExecutorMetaMap = std::unordered_map output_names) = 0; + DLL_PUBLIC virtual void Build(const graph::OpGraph &graph) = 0; + // TODO(michalz): Remove + DLL_PUBLIC virtual void Build(OpGraph *graph, std::vector output_names) = 0; DLL_PUBLIC virtual void Init() = 0; DLL_PUBLIC virtual void Run() = 0; DLL_PUBLIC virtual void Prefetch() = 0; diff --git a/dali/pipeline/executor/executor_impl.h b/dali/pipeline/executor/executor_impl.h index 3376faa6817..ccbfc55c5bc 100644 --- a/dali/pipeline/executor/executor_impl.h +++ b/dali/pipeline/executor/executor_impl.h @@ -130,11 +130,21 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy { DLL_PUBLIC int InputFeedCount(std::string_view op_name) override; + DLL_PUBLIC void Build(const graph::OpGraph &graph) override { + lowered_graph_.Lower(graph); + std::vector output_names; + for (std::string_view out : graph.Outputs()) + output_names.emplace_back(out); + Build(&lowered_graph_, std::move(output_names)); + } + DLL_PUBLIC void Build(OpGraph *graph, vector output_names) override; DLL_PUBLIC OperatorBase *GetOperator(std::string_view instance_name) override; protected: + OpGraph lowered_graph_; + DLL_PUBLIC virtual void RunCPU(); DLL_PUBLIC virtual void RunMixed(); DLL_PUBLIC virtual void RunGPU(); diff --git a/dali/pipeline/executor/lowered_graph.cc b/dali/pipeline/executor/lowered_graph.cc index 7ac16509b3f..c4c2b74c09b 100644 --- a/dali/pipeline/executor/lowered_graph.cc +++ b/dali/pipeline/executor/lowered_graph.cc @@ -93,6 +93,19 @@ StorageDevice ParseStorageDevice(const std::string &io_device) { } // namespace +void OpGraph::Lower(const graph::OpGraph &definition) { + if (!op_nodes_.empty() || !tensor_nodes_.empty()) + throw std::logic_error("The target graph must be empty"); + for (auto &node : definition.OpNodes()) { + auto &lowered_op = AddOp(node.spec, node.instance_name); + lowered_op.definition = &node; + } + for (auto &t : tensor_nodes_) { + t.definition = definition.GetData(t.name); + } +} + + OpNode& OpGraph::PlaceNewOp(OpType op_type, const OpSpec &op_spec, std::string instance_name) { op_nodes_.emplace_back(); auto &node = op_nodes_.back(); diff --git a/dali/pipeline/executor/lowered_graph.h b/dali/pipeline/executor/lowered_graph.h index 4fe80a80cd1..de07c05d825 100644 --- a/dali/pipeline/executor/lowered_graph.h +++ b/dali/pipeline/executor/lowered_graph.h @@ -61,6 +61,8 @@ struct OpNode { virtual ~OpNode() = default; OpNode& operator=(const OpNode&) = delete; + const graph::OpNode *definition = nullptr; + OpNode(OpNode &&) = default; OpNode& operator=(OpNode &&) = default; @@ -71,6 +73,7 @@ struct OpNode { std::unique_ptr op; OpNodeId id = -1; + // TODO(michalz): Consider removing the (now) redundant fields and use the definition OpSpec spec; std::set parents, children; @@ -79,6 +82,7 @@ struct OpNode { // To reduce number of allocation of shapes in Setup std::vector output_desc; + // TODO(michalz): Consider removing the (now) redundant fields and use the definition std::string instance_name; OpType op_type = OpType::COUNT; OpPartitionId partition_index = -1; @@ -97,6 +101,10 @@ using consumer_edge_t = TensorMeta; // Second type of graph nodes. struct TensorNode { + // NOTE: TensorNode doesn't define the storage device, but TensorNode is taken from OpSpec + // where it's unambiguously associated with a storage device. + const graph::DataNode *definition = nullptr; + TensorNodeId id; std::string name; producer_edge_t producer; @@ -133,6 +141,8 @@ class DLL_PUBLIC OpGraph { DLL_PUBLIC inline ~OpGraph() = default; + void Lower(const graph::OpGraph &definition); + /** * @brief Adds an op with the input specification to the graph. */ diff --git a/dali/pipeline/executor/op_graph_test.cc b/dali/pipeline/executor/op_graph_test.cc index 5d435f4b91b..e1c9174f586 100644 --- a/dali/pipeline/executor/op_graph_test.cc +++ b/dali/pipeline/executor/op_graph_test.cc @@ -643,5 +643,84 @@ TEST_F(OpGraphTest, TestGetTensorOrigin) { EXPECT_EQ(graph.GetTensorOrigin(4), origin_4); } +inline bool operator==(const dali::TensorMeta &a, const dali::TensorMeta &b) { + return a.index == b.index && a.node == b.node && a.storage_device == b.storage_device; +} + +void CheckEqual(const OpGraph &g1, const OpGraph &g2) { + EXPECT_EQ(g1.NumOp(), g2.NumOp()) << "The number of operator nodes differs."; + EXPECT_EQ(g1.NumTensor(), g2.NumTensor()) << "The number of tensor nodes differs."; + EXPECT_EQ(g1.NumOp(OpType::CPU), g2.NumOp(OpType::CPU)) << "The numberof CPU nodes differs."; + EXPECT_EQ(g1.NumOp(OpType::GPU), g2.NumOp(OpType::GPU)) << "The numberof GPU nodes differs."; + EXPECT_EQ(g1.NumOp(OpType::MIXED), g2.NumOp(OpType::MIXED)) + << "The numberof mixed nodes differs."; + + if (::testing::Test::HasFailure()) + return; + + for (int i = 0; i < g1.NumOp(); i++) { + auto &n1 = g1.Node(i); + auto &n2 = g2.Node(i); + EXPECT_EQ(n1.id, n2.id) << " @ node " << i; + EXPECT_EQ(n1.instance_name, n2.instance_name) << " @ node " << i; + EXPECT_EQ(n1.spec.SchemaName(), n2.spec.SchemaName())<< " @ node " << i; + EXPECT_EQ(n1.children, n2.children) << " @ node " << i; + EXPECT_EQ(n1.parents, n2.parents) << " @ node " << i; + } + for (int i = 0; i < g1.NumTensor(); i++) { + auto &t1 = g1.Tensor(i); + auto &t2 = g2.Tensor(i); + EXPECT_EQ(t1.id, t2.id) << " @ node " << i; + EXPECT_EQ(t1.name, t2.name) << " @ node " << i; + EXPECT_EQ(t1.consumers, t2.consumers) << " @ node " << i; + EXPECT_EQ(t1.producer, t2.producer) << " @ node " << i; + } +} + +TEST_F(OpGraphTest, Lowering) { + OpSpec spec0 = this->PrepareSpec(OpSpec("ExternalSource") + .AddArg("device", "cpu") + .AddArg("device_id", 0) + .AddOutput("data", "cpu")); + + OpSpec spec1 = this->PrepareSpec(OpSpec("Copy") + .AddInput("data", "cpu") + .AddOutput("copy_0_data", "cpu")); + + OpSpec spec2 = this->PrepareSpec(OpSpec("MakeContiguous") + .AddInput("copy_0_data", "cpu") + .AddOutput("contiguous_data", "cpu")); + + OpSpec spec3 = this->PrepareSpec(OpSpec("PassthroughOp") + .AddInput("contiguous_data", "cpu") + .AddOutput("passthrough_data", "cpu")); + + OpSpec spec4 = this->PrepareSpec(OpSpec("Copy") + .AddInput("passthrough_data", "cpu") + .AddOutput("copy_1_data", "cpu")); + + graph::OpGraph::Builder b; + // This is the same graph as in TestGetTensorOrigin, but the topological order is not maintained. + b.Add("Copy1", spec4); // tensor node 4 + b.Add("ExternalSource", spec0); // tensor node 0 + b.Add("MakeContiguous", spec2); // tensor node 2 + b.Add("Passthrough", spec3); // tensor node 3 + b.Add("Copy0", spec1); // tensor node 1 + b.AddOutput("copy_1_data_cpu"); + + auto def = std::move(b).GetGraph(true); + OpGraph lowered; + lowered.Lower(def); + + OpGraph handmade; + handmade.AddOp(spec0, "ExternalSource"); // tensor node 0 + handmade.AddOp(spec1, "Copy0"); // tensor node 1 + handmade.AddOp(spec2, "MakeContiguous"); // tensor node 2 + handmade.AddOp(spec3, "Passthrough"); // tensor node 3 + handmade.AddOp(spec4, "Copy1"); + + CheckEqual(lowered, handmade); +} + } // namespace dali diff --git a/dali/pipeline/graph/op_graph2.h b/dali/pipeline/graph/op_graph2.h index 00b26ab48dc..b42c550edb8 100644 --- a/dali/pipeline/graph/op_graph2.h +++ b/dali/pipeline/graph/op_graph2.h @@ -130,25 +130,24 @@ class DLL_PUBLIC OpGraph { return data_nodes_; } + /** Returns an OpNode with a matching instance name or nullptr. */ OpNode *GetOp(std::string_view instance_name) { - auto it = name2op_.find(instance_name); - if (it != name2op_.end()) - return &*it->second; - else - return nullptr; + return GetOpImpl(instance_name); } - /** Returns a DataNode with a matching name or nullptr. - * - * @param data_node_name - * @return DataNode* - */ + /** Returns an OpNode with a matching instance name or nullptr. */ + const OpNode *GetOp(std::string_view instance_name) const { + return GetOpImpl(instance_name); + } + + /** Returns a DataNode with a matching name or nullptr. */ DataNode *GetData(std::string_view data_node_name) { - auto it = name2data_.find(data_node_name); - if (it != name2data_.end()) - return &*it->second; - else - return nullptr; + return GetDataImpl(data_node_name); + } + + /** Returns a DataNode with a matching name or nullptr. */ + const DataNode *GetData(std::string_view data_node_name) const { + return GetDataImpl(data_node_name); } /** Sorts the graph topologically and removes entries that do not contribute to essential nodes. @@ -200,6 +199,22 @@ class DLL_PUBLIC OpGraph { } private: + OpNode *GetOpImpl(std::string_view instance_name) const { + auto it = name2op_.find(instance_name); + if (it != name2op_.end()) + return &*it->second; + else + return nullptr; + } + + DataNode *GetDataImpl(std::string_view data_node_name) const { + auto it = name2data_.find(data_node_name); + if (it != name2data_.end()) + return &*it->second; + else + return nullptr; + } + void RemoveDataNodeReferences(OpNode &op); OpNodeList op_nodes_;