Graph lowering. (NVIDIA#5496)
This change implements the "lowering" of the new graph::OpGraph to the old OpGraph used by the current executor.
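
A minimal usage sketch of the new entry point. The operator names, specs and executor type below are illustrative assumptions; only the Builder calls and Build(const graph::OpGraph &) come from this change and its tests:

// Sketch only: "MyExecutor" stands for any concrete executor derived from ExecutorBase;
// the OpSpec objects (reader_spec, decoder_spec) are assumed to be prepared elsewhere.
dali::graph::OpGraph::Builder b;
b.Add("Reader", reader_spec);
b.Add("Decoder", decoder_spec);
b.AddOutput("images_cpu");                                      // output name with a device suffix
dali::graph::OpGraph definition = std::move(b).GetGraph(true);  // sort & prune, as in the tests

MyExecutor exec(/* executor parameters */);
exec.Build(definition);   // lowers the definition to the legacy OpGraph internally and builds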

----

Signed-off-by: Michal Zientkiewicz <[email protected]>
mzient authored Jun 11, 2024
1 parent d265815 commit f850059
Showing 6 changed files with 145 additions and 16 deletions.
4 changes: 3 additions & 1 deletion dali/pipeline/executor/executor.h
@@ -40,7 +40,9 @@ using ExecutorMetaMap = std::unordered_map<std::string, std::vector<ExecutorMeta
class DLL_PUBLIC ExecutorBase {
public:
DLL_PUBLIC virtual ~ExecutorBase() {}
DLL_PUBLIC virtual void Build(OpGraph *graph, vector<string> output_names) = 0;
DLL_PUBLIC virtual void Build(const graph::OpGraph &graph) = 0;
// TODO(michalz): Remove
DLL_PUBLIC virtual void Build(OpGraph *graph, std::vector<std::string> output_names) = 0;
DLL_PUBLIC virtual void Init() = 0;
DLL_PUBLIC virtual void Run() = 0;
DLL_PUBLIC virtual void Prefetch() = 0;
10 changes: 10 additions & 0 deletions dali/pipeline/executor/executor_impl.h
@@ -130,11 +130,21 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {

DLL_PUBLIC int InputFeedCount(std::string_view op_name) override;

DLL_PUBLIC void Build(const graph::OpGraph &graph) override {
lowered_graph_.Lower(graph);
std::vector<std::string> output_names;
for (std::string_view out : graph.Outputs())
output_names.emplace_back(out);
Build(&lowered_graph_, std::move(output_names));
}

DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override;

DLL_PUBLIC OperatorBase *GetOperator(std::string_view instance_name) override;

protected:
OpGraph lowered_graph_;

DLL_PUBLIC virtual void RunCPU();
DLL_PUBLIC virtual void RunMixed();
DLL_PUBLIC virtual void RunGPU();
13 changes: 13 additions & 0 deletions dali/pipeline/executor/lowered_graph.cc
@@ -93,6 +93,19 @@ StorageDevice ParseStorageDevice(const std::string &io_device) {

} // namespace

void OpGraph::Lower(const graph::OpGraph &definition) {
if (!op_nodes_.empty() || !tensor_nodes_.empty())
throw std::logic_error("The target graph must be empty");
for (auto &node : definition.OpNodes()) {
auto &lowered_op = AddOp(node.spec, node.instance_name);
lowered_op.definition = &node;
}
for (auto &t : tensor_nodes_) {
t.definition = definition.GetData(t.name);
}
}


OpNode& OpGraph::PlaceNewOp(OpType op_type, const OpSpec &op_spec, std::string instance_name) {
op_nodes_.emplace_back();
auto &node = op_nodes_.back();
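For reference, a short sketch of how Lower is meant to be used, based on the code above and the test below: lowering populates the op and tensor nodes from the definition and sets the definition back-pointers; it is a one-shot operation, so calling it on a non-empty graph throws. The `definition` variable is assumed to be a sorted graph::OpGraph.

OpGraph lowered;                 // the legacy executor graph
lowered.Lower(definition);       // ok: op/tensor nodes are created, back-pointers are set
// lowered.Lower(definition);    // would throw std::logic_error("The target graph must be empty")

for (int i = 0; i < lowered.NumOp(); i++)       // requires <cassert>
  assert(lowered.Node(i).definition != nullptr);  // every lowered op points back to its definition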
10 changes: 10 additions & 0 deletions dali/pipeline/executor/lowered_graph.h
@@ -61,6 +61,8 @@ struct OpNode {
virtual ~OpNode() = default;
OpNode& operator=(const OpNode&) = delete;

const graph::OpNode *definition = nullptr;

OpNode(OpNode &&) = default;
OpNode& operator=(OpNode &&) = default;

@@ -71,6 +73,7 @@

std::unique_ptr<OperatorBase> op;
OpNodeId id = -1;
// TODO(michalz): Consider removing the (now) redundant fields and use the definition
OpSpec spec;
std::set<OpNodeId> parents, children;

@@ -79,6 +82,7 @@
// To reduce number of allocation of shapes in Setup
std::vector<OutputDesc> output_desc;

// TODO(michalz): Consider removing the (now) redundant fields and use the definition
std::string instance_name;
OpType op_type = OpType::COUNT;
OpPartitionId partition_index = -1;
@@ -97,6 +101,10 @@ using consumer_edge_t = TensorMeta;

// Second type of graph nodes.
struct TensorNode {
// NOTE: TensorNode doesn't define the storage device, but its name is taken from OpSpec,
// where it's unambiguously associated with a storage device.
const graph::DataNode *definition = nullptr;

TensorNodeId id;
std::string name;
producer_edge_t producer;
@@ -133,6 +141,8 @@ class DLL_PUBLIC OpGraph {

DLL_PUBLIC inline ~OpGraph() = default;

void Lower(const graph::OpGraph &definition);

/**
* @brief Adds an op with the input specification to the graph.
*/
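The new definition back-pointers eventually make the copied spec and instance_name fields redundant (see the TODOs above). A hedged sketch of reading through the back-pointer instead, assuming graph::OpNode exposes spec and instance_name as the Lower implementation suggests:

const OpNode &node = lowered_graph.Node(i);
// Prefer the definition when it is present (graphs built directly via AddOp have none).
const OpSpec &spec = node.definition ? node.definition->spec : node.spec;
std::string_view name = node.definition ? std::string_view(node.definition->instance_name)
                                        : std::string_view(node.instance_name);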
79 changes: 79 additions & 0 deletions dali/pipeline/executor/op_graph_test.cc
@@ -643,5 +643,84 @@ TEST_F(OpGraphTest, TestGetTensorOrigin) {
EXPECT_EQ(graph.GetTensorOrigin(4), origin_4);
}

inline bool operator==(const dali::TensorMeta &a, const dali::TensorMeta &b) {
return a.index == b.index && a.node == b.node && a.storage_device == b.storage_device;
}

void CheckEqual(const OpGraph &g1, const OpGraph &g2) {
EXPECT_EQ(g1.NumOp(), g2.NumOp()) << "The number of operator nodes differs.";
EXPECT_EQ(g1.NumTensor(), g2.NumTensor()) << "The number of tensor nodes differs.";
EXPECT_EQ(g1.NumOp(OpType::CPU), g2.NumOp(OpType::CPU)) << "The number of CPU nodes differs.";
EXPECT_EQ(g1.NumOp(OpType::GPU), g2.NumOp(OpType::GPU)) << "The number of GPU nodes differs.";
EXPECT_EQ(g1.NumOp(OpType::MIXED), g2.NumOp(OpType::MIXED))
<< "The number of mixed nodes differs.";

if (::testing::Test::HasFailure())
return;

for (int i = 0; i < g1.NumOp(); i++) {
auto &n1 = g1.Node(i);
auto &n2 = g2.Node(i);
EXPECT_EQ(n1.id, n2.id) << " @ node " << i;
EXPECT_EQ(n1.instance_name, n2.instance_name) << " @ node " << i;
EXPECT_EQ(n1.spec.SchemaName(), n2.spec.SchemaName()) << " @ node " << i;
EXPECT_EQ(n1.children, n2.children) << " @ node " << i;
EXPECT_EQ(n1.parents, n2.parents) << " @ node " << i;
}
for (int i = 0; i < g1.NumTensor(); i++) {
auto &t1 = g1.Tensor(i);
auto &t2 = g2.Tensor(i);
EXPECT_EQ(t1.id, t2.id) << " @ node " << i;
EXPECT_EQ(t1.name, t2.name) << " @ node " << i;
EXPECT_EQ(t1.consumers, t2.consumers) << " @ node " << i;
EXPECT_EQ(t1.producer, t2.producer) << " @ node " << i;
}
}

TEST_F(OpGraphTest, Lowering) {
OpSpec spec0 = this->PrepareSpec(OpSpec("ExternalSource")
.AddArg("device", "cpu")
.AddArg("device_id", 0)
.AddOutput("data", "cpu"));

OpSpec spec1 = this->PrepareSpec(OpSpec("Copy")
.AddInput("data", "cpu")
.AddOutput("copy_0_data", "cpu"));

OpSpec spec2 = this->PrepareSpec(OpSpec("MakeContiguous")
.AddInput("copy_0_data", "cpu")
.AddOutput("contiguous_data", "cpu"));

OpSpec spec3 = this->PrepareSpec(OpSpec("PassthroughOp")
.AddInput("contiguous_data", "cpu")
.AddOutput("passthrough_data", "cpu"));

OpSpec spec4 = this->PrepareSpec(OpSpec("Copy")
.AddInput("passthrough_data", "cpu")
.AddOutput("copy_1_data", "cpu"));

graph::OpGraph::Builder b;
// This is the same graph as in TestGetTensorOrigin, but the topological order is not maintained.
b.Add("Copy1", spec4); // tensor node 4
b.Add("ExternalSource", spec0); // tensor node 0
b.Add("MakeContiguous", spec2); // tensor node 2
b.Add("Passthrough", spec3); // tensor node 3
b.Add("Copy0", spec1); // tensor node 1
b.AddOutput("copy_1_data_cpu");

auto def = std::move(b).GetGraph(true);
OpGraph lowered;
lowered.Lower(def);

OpGraph handmade;
handmade.AddOp(spec0, "ExternalSource"); // tensor node 0
handmade.AddOp(spec1, "Copy0"); // tensor node 1
handmade.AddOp(spec2, "MakeContiguous"); // tensor node 2
handmade.AddOp(spec3, "Passthrough"); // tensor node 3
handmade.AddOp(spec4, "Copy1");

CheckEqual(lowered, handmade);
}


} // namespace dali
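
CheckEqual compares only the legacy fields; a possible extra check (not part of this commit, sketched under the assumption that graph::OpNode exposes instance_name, as Lower uses it) would verify that Lower also set the back-pointers to the definition graph:

for (int i = 0; i < lowered.NumOp(); i++) {
  ASSERT_NE(lowered.Node(i).definition, nullptr) << " @ node " << i;
  EXPECT_EQ(lowered.Node(i).definition->instance_name, lowered.Node(i).instance_name) << " @ node " << i;
}
for (int i = 0; i < lowered.NumTensor(); i++)
  EXPECT_NE(lowered.Tensor(i).definition, nullptr) << " @ tensor " << i;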
45 changes: 30 additions & 15 deletions dali/pipeline/graph/op_graph2.h
@@ -130,25 +130,24 @@ class DLL_PUBLIC OpGraph {
return data_nodes_;
}

/** Returns an OpNode with a matching instance name or nullptr. */
OpNode *GetOp(std::string_view instance_name) {
auto it = name2op_.find(instance_name);
if (it != name2op_.end())
return &*it->second;
else
return nullptr;
return GetOpImpl(instance_name);
}

/** Returns a DataNode with a matching name or nullptr.
*
* @param data_node_name
* @return DataNode*
*/
/** Returns an OpNode with a matching instance name or nullptr. */
const OpNode *GetOp(std::string_view instance_name) const {
return GetOpImpl(instance_name);
}

/** Returns a DataNode with a matching name or nullptr. */
DataNode *GetData(std::string_view data_node_name) {
auto it = name2data_.find(data_node_name);
if (it != name2data_.end())
return &*it->second;
else
return nullptr;
return GetDataImpl(data_node_name);
}

/** Returns a DataNode with a matching name or nullptr. */
const DataNode *GetData(std::string_view data_node_name) const {
return GetDataImpl(data_node_name);
}

/** Sorts the graph topologically and removes entries that do not contribute to essential nodes.
@@ -200,6 +199,22 @@
}

private:
OpNode *GetOpImpl(std::string_view instance_name) const {
auto it = name2op_.find(instance_name);
if (it != name2op_.end())
return &*it->second;
else
return nullptr;
}

DataNode *GetDataImpl(std::string_view data_node_name) const {
auto it = name2data_.find(data_node_name);
if (it != name2data_.end())
return &*it->second;
else
return nullptr;
}

void RemoveDataNodeReferences(OpNode &op);

OpNodeList op_nodes_;
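The GetOp/GetData refactor above is the usual pattern of const and non-const accessors delegating to a single private implementation. A minimal, generic sketch of the same idea (the names below are illustrative, not DALI API): the private impl can be a const member and still return a non-const pointer, because the map stores list iterators and iterator constness does not propagate to the referenced element; the public const overload then restores const-correctness for callers.

#include <list>
#include <map>
#include <string>
#include <string_view>

struct Widget { std::string name; };

class Registry {
 public:
  Widget *Get(std::string_view name)             { return GetImpl(name); }
  const Widget *Get(std::string_view name) const { return GetImpl(name); }  // const view for const callers
 private:
  Widget *GetImpl(std::string_view name) const {
    // Heterogeneous lookup (std::less<>) lets us search with a string_view without allocating.
    auto it = name2widget_.find(name);
    return it != name2widget_.end() ? &*it->second : nullptr;
  }
  std::list<Widget> widgets_;
  std::map<std::string, std::list<Widget>::iterator, std::less<>> name2widget_;
};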
