Commit 06156da

authored
net design with NetBuilder (#2598)
* move net_design to framework
* change CreateNet result to unique_ptr
* rename "ScratchNet" -> "PlainNet"
* add three methods to NetBase
* add NetBuilder
* add InferShape to NetBuilder.Run
* rename ApplyGradient, ApplyOptimizer -> AddGradientOps, AddOptimiz
* rename PlainNet::CreateNet -> BuildNet
* add Error and other rename actions
1 parent 0140eb9 commit 06156da

1 file changed

+250
-0
lines changed

paddle/framework/net_design.md

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# Network Design

`Network` is the container and controller of a set of operators.
Users can build a real network from a `NetDesc`, which is a protobuf message,
and use `Network.Run()` to run all the operators in the network.

A network object knows all the operators belonging to it. Variables,
which are the inputs and outputs of these operators,
are created and managed by a hierarchy of `Scope` objects.

# API

## Net
To make the `Network` extensible, a base class is defined as follows:

```c++
// An operator's index within a network.
typedef int OpIndex;

// The minimal interface a network should implement.
class Net {
 public:
  // Networks are owned and destroyed through `Net` pointers.
  virtual ~Net() {}

  // Run all the operators and return an `Error` indicating success or
  // failure. All the variables are located in `scope`. `context` describes
  // the detailed execution environment for ops. `begin` and `end` specify the
  // range of `ops_` to run; if no non-negative indexes are provided, all
  // operators in `ops_` will run.
  virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1,
                    OpIndex end = -1) const = 0;

  // Add an Operator according to `def`.
  virtual OpIndex AddOp(const proto::OpDef &def) = 0;

  // Add optimizer operators according to `attrs`.
  virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0;

  // Add backward operators.
  virtual Error AddBackwardOps() = 0;

  // Infer the shapes of variables required by operators in the network. The
  // `scope` will be mutated according to the inferred shapes.
  virtual Error InferShape(Scope *scope) = 0;

  static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc());
};
```

All network implementations should build networks from a protobuf message that
describes the structure of a real network. Every implementation should provide a
`Run` method, which offers a uniform way to run the forward or backward computation of a network.

`Net::Create` is a factory method and can be implemented as follows:

```c++
std::unique_ptr<Net> Net::Create(const NetDesc& def) {
  switch (def.model_type()) {
    case NN:
      // unique_ptr's constructor from a raw pointer is explicit.
      return std::unique_ptr<Net>(new Network(def));
    case Recursive:
      return std::unique_ptr<Net>(new RecursiveNet(def));
    case Recurrent:
      return std::unique_ptr<Net>(new RecurrentNet(def));
  }
  // Unknown model type.
  return nullptr;
}
```
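The factory above depends on types that this document does not define. As a minimal, compilable sketch of the same idea, the block below replaces `NetDesc` with a plain struct and adds a `Name()` method so the chosen implementation is observable. The names `ModelType`, `kNN`, `kRecurrent`, `kUnknown`, and `Name()` are assumptions made for illustration and are not part of the proposed API.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins: the real `NetDesc` is a protobuf message, and the
// enum `ModelType` is an assumption made for this sketch.
enum class ModelType { kNN, kRecurrent, kUnknown };

struct NetDesc {
  ModelType model_type = ModelType::kNN;
};

class Net {
 public:
  virtual ~Net() = default;
  // Name of the concrete network type; only here to make the result visible.
  virtual std::string Name() const = 0;
  // Factory: returns nullptr when the description names an unknown type.
  static std::unique_ptr<Net> Create(const NetDesc& def);
};

class PlainNet : public Net {
 public:
  explicit PlainNet(const NetDesc&) {}
  std::string Name() const override { return "PlainNet"; }
};

class RecurrentNet : public Net {
 public:
  explicit RecurrentNet(const NetDesc&) {}
  std::string Name() const override { return "RecurrentNet"; }
};

std::unique_ptr<Net> Net::Create(const NetDesc& def) {
  switch (def.model_type) {
    case ModelType::kNN:
      return std::unique_ptr<Net>(new PlainNet(def));
    case ModelType::kRecurrent:
      return std::unique_ptr<Net>(new RecurrentNet(def));
    default:
      return nullptr;
  }
}
```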

`Network` is designed as a container of operators. To make it more extensible,
we decouple it from the related variable resources.

`Run(Scope* scope)` takes the scope as an argument so that it can run in different scopes.

Finally, `Net` can be used as follows:

```c++
Scope default_scope;
OpContext default_context;
auto net = Net::Create(def);

if (net) {
  // `net` is a std::unique_ptr<Net>, so use `->`.
  net->Run(&default_scope, &default_context);
}
```
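To illustrate why passing the scope into `Run` matters, here is a small sketch in which one network object is run against two independent scopes. `Scope` is reduced to a map from variable name to value and an operator to a function over a scope; `ToyNet`, `Op`, and `MakeAffineNet` are hypothetical names used only for this example.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for `Scope`: a flat map from variable name to value.
using Scope = std::map<std::string, double>;
// Hypothetical stand-in for an operator: a function that mutates a scope.
using Op = std::function<void(Scope*)>;

// A toy network that owns operators only; it holds no variables itself.
class ToyNet {
 public:
  void AddOp(Op op) { ops_.push_back(std::move(op)); }

  // Variables are looked up in whichever scope the caller passes in.
  void Run(Scope* scope) const {
    for (const Op& op : ops_) op(scope);
  }

 private:
  std::vector<Op> ops_;
};

// Builds a net that computes y = 2 * x + 1 in the given scope.
ToyNet MakeAffineNet() {
  ToyNet net;
  net.AddOp([](Scope* s) { (*s)["y"] = 2.0 * (*s)["x"]; });
  net.AddOp([](Scope* s) { (*s)["y"] += 1.0; });
  return net;
}
```

Running the result of `MakeAffineNet()` on a scope holding `x = 1` and then on another holding `x = 5` leaves `y = 3` and `y = 11` in the respective scopes, without the net itself changing.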

## `PlainNet` as a simple implementation of `Net`

A very basic implementation is as follows. It simply runs every operator in sequence.

```c++
class PlainNet : public Net {
 public:
  // Create a network described by `def`. NetDesc is the definition of a network.
  PlainNet(const NetDesc &def);

  // Infer the shapes of all the operators' input and output variables. Will
  // be called before every mini-batch training.
  virtual Error InferShape(Scope *scope) override;

  // Run all the operators with the `scope`. If no scope is provided, the
  // default scope will be used instead. If no OpContext is provided, the
  // default context will be used.
  virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr,
                    OpIndex begin = -1, OpIndex end = -1) const override;

  virtual OpIndex AddOp(const proto::OpDef &def) override;

  virtual Error AddOptimizerOps(const OptAttrs &attrs) override;

  virtual Error AddBackwardOps() override;

 protected:
  // Create operators according to `def`. Will be called by the constructor.
  Error BuildNet(const NetDesc &def);

  // Add an operator which is identified as `type` and has attributes described
  // in `attrs`. The `inputs` are the keys of read-only input variables, and
  // `outputs` are the keys of mutable output variables. An `OpIndex` will be
  // returned to indicate the offset of the new operator in `ops_`.
  OpIndex AddOp(const std::string &type, const std::vector<std::string> &inputs,
                const std::vector<std::string> &outputs,
                const OprAttr &attrs = OprAttr());

 private:
  // The operators owned by this network.
  std::vector<Operator> ops_;
};
```
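The document does not spell out how `begin` and `end` are interpreted by `Run`. The sketch below shows one possible reading, under two loudly stated assumptions: a negative index means "unspecified", and the range is half-open, `[begin, end)`. Operators are reduced to strings and "running" one just appends its name to a trace; `SequentialNet` and `Trace` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <vector>

typedef int OpIndex;
// Hypothetical record of which operators ran, in order.
using Trace = std::vector<std::string>;

class SequentialNet {
 public:
  // Appends an operator and returns its offset in `ops_`.
  OpIndex AddOp(const std::string& type) {
    ops_.push_back(type);
    return static_cast<OpIndex>(ops_.size()) - 1;
  }

  // Runs operators in [begin, end). Assumption: a negative `begin` means
  // "from the first operator" and a negative `end` means "through the last".
  // Returns false if the resolved range is invalid.
  bool Run(Trace* trace, OpIndex begin = -1, OpIndex end = -1) const {
    const OpIndex n = static_cast<OpIndex>(ops_.size());
    if (begin < 0) begin = 0;
    if (end < 0) end = n;
    if (begin > end || end > n) return false;
    for (OpIndex i = begin; i < end; ++i) trace->push_back(ops_[i]);
    return true;
  }

 private:
  std::vector<std::string> ops_;
};
```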

`PlainNet` owns its operators in the private member `ops_`.
The operators are created by `BuildNet`, which creates each operator by calling `AddOp`.


## PlainNet Usage
`PlainNet` can be used to define and run a network as follows:

```c++
// create an empty scope located on CPU device.
Scope scope(CPUPlace());

// create and init variables described in `net_desc`.
scope.CreateVariables(net_desc);
scope.InitVariables(net_desc);

// create a network according to `net_desc`
auto net = Net::Create(net_desc);
// Add more operators if needed.
net->AddOp(add...);
net->AddOp(fc...);

net->AddBackwardOps();
net->AddOptimizerOps();

// run the network providing the `scope`.
net->Run(&scope);
```

## `NetBuilder` as a C++ syntax wrapper
This is a detailed description of the user-facing C++ network API, and may not be needed in the prototype development stage.

The `NetBuilder` gives users a much simpler syntax for creating a network, as shown below, and demonstrates how to use the raw interfaces of `Net`.

```c++
// Pseudocode: keyword-style arguments are shown for readability.
Variable* fc_out = builder.AddOp("fc", input=image, size=100, activation="Sigmoid");
Variable* prediction = builder.AddOp("fc", input=fc_out, size=10, activation="Sigmoid");
Variable* loss = builder.AddOp("cross_entropy", input=prediction, label=label);
Variable* avg_loss = builder.AddOp("mean", loss);

builder.BackwardFrom(avg_loss);
builder.AddOptimization(1e-4, "adam");
builder.Run();
```

`NetBuilder` calls the virtual functions of `Net` to change the real network structure. Here is a sample definition:

```c++
class NetBuilder final {
 public:
  NetBuilder(Net* net) : net_(net) {}

  Variable* AddOp(const string& type, const vector<Variable>& inputs,
                  size_t size, Activation act) {
    // much code here.
    // ...
    net_->AddOp(def);
    need_rebuild_net_ = true;
    net_->InferShape();
    // ...
  }

  Error BackwardFrom(const Variable& cost);

  Error Run(Scope* scope, OpContext* context, bool need_backward = true) {
    // backward.
    if (need_backward) {
      if (need_rebuild_net_) {
        AddBackwardOps();
        AddOptimizerOps();
        // Avoid adding the same operators again on the next call.
        need_rebuild_net_ = false;
      }
      return net_->Run(scope, context);
    }
    // just forward.
    return net_->Run(scope, context, 0, last_forward_op_);
  }

 protected:
  Error AddBackwardOps();
  Error AddOptimizerOps();

 private:
  Net* net_;
  OpIndex last_forward_op_{-1};
  bool need_rebuild_net_{true};
};
```
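The rebuild logic in `NetBuilder::Run` can be exercised in isolation. The following is a toy sketch, not the proposed implementation: operators are plain strings, "adding backward ops" appends one `<op>_grad` entry per forward operator in reverse order, and a single `"sgd"` entry stands in for the optimizer ops. `ToyBuilder`, the `_grad` suffix, and `"sgd"` are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy builder that mirrors the `need_rebuild_net_` flag.
class ToyBuilder {
 public:
  void AddOp(const std::string& type) {
    // Adding a forward op invalidates previously generated backward ops.
    ops_.resize(num_forward_);
    ops_.push_back(type);
    ++num_forward_;
    need_rebuild_ = true;
  }

  // Returns the operators that would run, in order.
  std::vector<std::string> Run(bool need_backward = true) {
    if (!need_backward) {
      // Forward only: the first `num_forward_` operators.
      return std::vector<std::string>(ops_.begin(),
                                      ops_.begin() + num_forward_);
    }
    if (need_rebuild_) {
      // Append gradient ops in reverse order, then the optimizer op.
      for (std::size_t i = num_forward_; i-- > 0;) {
        ops_.push_back(ops_[i] + "_grad");
      }
      ops_.push_back("sgd");
      need_rebuild_ = false;
    }
    return ops_;
  }

 private:
  std::vector<std::string> ops_;
  std::size_t num_forward_ = 0;
  bool need_rebuild_ = true;
};
```

Because the flag is cleared after a rebuild, calling `Run()` twice in a row yields the same operator list, and only a later `AddOp` triggers regeneration of the backward and optimizer entries.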

## Compatibility with RNN

Because `PlainNet.Run` is decoupled from `Scope`, `PlainNet` is compatible with the future RNN design.
For example, we can implement a simple recurrent neural network as follows:

```c++
// copy some `vars` from `source` to `target`
void Copy(const Scope &source, Scope &target,
          const std::vector<std::string> &vars);

Scope default_scope;
// some initial mutations on `default_scope` here.

PlainNet rnn_step_net(rnn_step_net_def);

// Create rnn's states, the last scope is used to store rnn outputs.
Scope *rnn_states = new Scope[num_states + 1];

for (int i = 0; i < num_states + 1; i++) {
  // Initialize all rnn state scopes, copy parameters and so on.
  rnn_states[i].CreateVars(rnn_step_net_def);
  Copy(default_scope, rnn_states[i], rnn_related_vars);
  // Prepare rnn's inlinks, just copy inlink variables to each state.
  Copy(default_scope, rnn_states[i], inlink_vars);
}

// Run the rnn.
for (int i = 0; i < num_states; i++) {
  // `Run` takes a pointer to the scope.
  rnn_step_net.Run(&rnn_states[i]);
  // Copy current state's state variables to next state, the related variables
  // are named like "previous_state_xxx".
  Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars);
}

// Copy rnn's final outputs to `default_scope`.
Copy(rnn_states[num_states], default_scope, outlink_vars);

// Release the state scopes.
delete[] rnn_states;
```
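The unrolling pattern above can be tried with toy types. In this sketch a scope is a map from variable name to a scalar, and the step network is a single function computing `state = w * previous_state + x`. The variable names `w`, `x`, `state`, and `previous_state`, as well as `RunStep` and `RunRnn`, are assumptions for illustration only; `Copy` here takes the target by pointer, unlike the declaration above.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for `Scope`.
using Scope = std::map<std::string, double>;

// Copies the variables named in `vars` from `source` to `target`.
void Copy(const Scope& source, Scope* target,
          const std::vector<std::string>& vars) {
  for (const std::string& name : vars) {
    auto it = source.find(name);
    if (it != source.end()) (*target)[name] = it->second;
  }
}

// One step of the toy RNN: state = w * previous_state + x.
void RunStep(Scope* scope) {
  (*scope)["state"] =
      (*scope)["w"] * (*scope)["previous_state"] + (*scope)["x"];
}

// Unrolls the step over `inputs` with one scope per time step. The extra last
// scope only receives the final state, mirroring the output scope above.
double RunRnn(double w, const std::vector<double>& inputs) {
  Scope shared;
  shared["w"] = w;

  std::vector<Scope> steps(inputs.size() + 1);
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    Copy(shared, &steps[i], {"w"});  // copy parameters into each step scope
    steps[i]["x"] = inputs[i];       // inlink for this time step
  }
  steps[0]["previous_state"] = 0.0;

  for (std::size_t i = 0; i < inputs.size(); ++i) {
    RunStep(&steps[i]);
    // Hand the new state to the next step under the "previous_state" name.
    steps[i + 1]["previous_state"] = steps[i]["state"];
  }
  return steps[inputs.size()]["previous_state"];
}
```

With `w = 0.5` and inputs `1, 2, 3`, the states are `1`, `2.5`, and `4.25`, so `RunRnn` returns `4.25`.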
