|
| 1 | +# Network Design |
| 2 | + |
| 3 | +`Network` is the container and controller of a set of operators, |
| 4 | +user can build a real network from a `NetDesc` which is a protobuf message |
| 5 | +and use `Network.Run()` to run all the operators in the network. |
| 6 | + |
| 7 | +A network object knows all Operators belonging to this network. Variables, |
| 8 | +which are inputs and outputs of these operators, |
| 9 | +are created and managed by a hierarchy of Scope objects. |
| 10 | + |
| 11 | +# API |
| 12 | + |
| 13 | +## Net |
| 14 | +To make the `Network` extendable, a base class is defined like this |
| 15 | + |
| 16 | +```c++ |
| 17 | +// operator's index stored in a network. |
| 18 | +typedef int OpIndex; |
| 19 | + |
| 20 | +// The minimum a network should be implemented. |
| 21 | +class Net { |
| 22 | + public: |
| 23 | + // run all the operators and return success(true) or not, with all the |
| 24 | + // variables are located in `scope`. `context` describes the detail execution |
| 25 | + // environment for ops. `begin` and `end` specify the scope of `ops_` to run, |
| 26 | + // If no positive indexes are provided, all operators in `ops_` will run. |
| 27 | + virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, |
| 28 | + OpIndex end = -1) const = 0; |
| 29 | + |
| 30 | + // Add an Operator according to `def`. |
| 31 | + virtual OpIndex AddOp(const proto::OpDef &def) = 0; |
| 32 | + |
| 33 | + // Add optimizer operators acctording to `attrs`. |
| 34 | + virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0; |
| 35 | + |
| 36 | + // Add backward operators. |
| 37 | + virtual Error AddBackwardOps() = 0; |
| 38 | + |
| 39 | + // Infer the shapes of variables required by operators in the network. The |
| 40 | + // `scope` will be mutated according to the inferred shapes. |
| 41 | + |
| 42 | + static std::unique_ptr<Net> Create(const NetDesc &def = NetDesc()); |
| 43 | +}; |
| 44 | +``` |
| 45 | +
|
| 46 | +All network implementations should build networks from a protobuf message which |
| 47 | +describes the structure of a real network; `Run` method should be implemented by |
| 48 | +all implementations to offer a universal method to forward or backward compute a network. |
| 49 | +
|
| 50 | +`Net::Create` is a method of factory pattern and can be implemented like |
| 51 | +
|
| 52 | +```c++ |
| 53 | +std::unique<Net> Net::Create(const NetDesc& def) { |
| 54 | + switch (def.model_type()) { |
| 55 | + case NN: |
| 56 | + return new Network(def); |
| 57 | + case Recursive: |
| 58 | + return new RecursiveNet(def); |
| 59 | + case Recurrent: |
| 60 | + return new RecurrentNet(def); |
| 61 | + } |
| 62 | + return nullptr; |
| 63 | +} |
| 64 | +``` |
| 65 | + |
| 66 | +Network is designed as the container of operators. to make it more extendable, |
| 67 | +we decouple it from the related variable resources. |
| 68 | + |
| 69 | +`Run(Scope* scope)` takes the scope as a argument so that it can run in different scopes. |
| 70 | + |
| 71 | +Finally, `Net` can be used as followed |
| 72 | + |
| 73 | +```c++ |
| 74 | +Scope default_scope; |
| 75 | +OpContext default_context; |
| 76 | +auto net = Net::CreateNet(def); |
| 77 | + |
| 78 | +if (net) { |
| 79 | + net.Run(&default_scope, &default_context); |
| 80 | +} |
| 81 | +``` |
| 82 | + |
| 83 | +## `PlainNet` as a simple implementation of `BaseNet` |
| 84 | + |
| 85 | +A very basic implementation is as follows. All it does is simply to run every operators in sequence. |
| 86 | + |
| 87 | +```c++ |
| 88 | +class PlainNet : public Net { |
| 89 | + public: |
| 90 | + // Create a network describe by `def`. NetDesc is the definition of a network. |
| 91 | + PlainNet(const NetDesc &def); |
| 92 | + |
| 93 | + // Infer all the operators' input and output varialbes' shapes, will be called before every mini-batch |
| 94 | + training. |
| 95 | + virtual Error InferShape(Scope *scope) override; |
| 96 | + |
| 97 | + // Run all the operators with the `scope`, if no scope is provided, default |
| 98 | + // scope will be used instead. If no OpContext is provicded, default context will be used. |
| 99 | + virtual Error Run(Scope *scope = nullptr, OpContext *context=nullptr, OpIndex begin = -1, |
| 100 | + OpIndex end = -1) const override; |
| 101 | + |
| 102 | + virtual OpIndex AddOp(const proto::OpDef &def) override; |
| 103 | + |
| 104 | + virtual Error AddOptimizerOps(const OptAttrs &attrs) override; |
| 105 | + |
| 106 | + virtual Error AddBackwardOps() override; |
| 107 | + |
| 108 | + protected: |
| 109 | + // Create operators accordding to `def`, will be called by the constructor. |
| 110 | + Error BuildNet(const NetDesc &def); |
| 111 | + |
| 112 | + // Add a operator which is identified as `type` and has attributes described |
| 113 | + // in `attrs`, the `inputs` are the keys of readonly input variables, |
| 114 | + // `outputs` are keys of mutable output variables. An `OpIndex` will be |
| 115 | + // returned to indicate the offset of the new operator in `ops_`. |
| 116 | + OpIndex AddOp(const std::string &type, const std::vector<string> &inputs, |
| 117 | + const std::vector<string> &outputs, |
| 118 | + const OprAttr &attrs = OprAttr()); |
| 119 | + |
| 120 | + private: |
| 121 | + // the operators owned by `Network`. |
| 122 | + std::vector<Operator> ops_; |
| 123 | +}; |
| 124 | +``` |
| 125 | +
|
| 126 | +`PlainNet` will create operators so that a private member `ops_` is defined, |
| 127 | +the operators are created by `CreateNet`, and each operator is created by `AddOp`. |
| 128 | +
|
| 129 | +
|
| 130 | +## PlainNet Usage |
| 131 | +`PlainNet` can be used to define and run a network as follows |
| 132 | +
|
| 133 | +```c++ |
| 134 | +// create an empty scope located on CPU device. |
| 135 | +Scope scope(CPUPlace()); |
| 136 | +
|
| 137 | +// create and init variables described in `net_desc`. |
| 138 | +scope.CreateVariables(net_desc); |
| 139 | +scope.InitVariables(net_desc); |
| 140 | +
|
| 141 | +// create a network according to `net_desc` |
| 142 | +auto net = Net::CreateNet(net_desc); |
| 143 | +// Add more operators if needed. |
| 144 | +net->AddOp(add...); |
| 145 | +net->AddOp(fc...); |
| 146 | +
|
| 147 | +net->AddBackwardOps(); |
| 148 | +net->AddOptimizerOps(); |
| 149 | +
|
| 150 | +// run the network providing the `scope`. |
| 151 | +net.Run(&scope); |
| 152 | +``` |
| 153 | + |
| 154 | +## `NetBuilder` as a C++ syntax wrapper |
| 155 | +This is a detailed description of the user-related C++ network API, and may not needed in the prototype development stage. |
| 156 | + |
| 157 | +The `NetBuilder` will give users a much simpler syntax as follows to create a network, and demonstrates how to use the `BaseNet`'s raw interfaces. |
| 158 | + |
| 159 | +```c++ |
| 160 | +Variable* fc_out = builder.AddOp("fc", input=image, size=100, activation="Sigmoid"); |
| 161 | +Variable* prediction = builder.AddOp("fc", input=fc_out, size=10, activation="Sigmoid"); |
| 162 | +Variable* loss = builder.AddOp("cross_entropy", input=prediction, label=label); |
| 163 | +Variable* avg_loss = builder.AddOp("mean", loss); |
| 164 | + |
| 165 | +builder.BackwardFrom(avg_loss) |
| 166 | +builder.AddOptimization(1e-4, "adam"); |
| 167 | +builder.Run(); |
| 168 | +``` |
| 169 | + |
| 170 | +`NetBuilder` will call `Net` 's virtual functions to change the real network structure, here is a sample definition |
| 171 | + |
| 172 | +```c++ |
| 173 | +class NetBuilder final { |
| 174 | + public: |
| 175 | + NetBuilder(Net* net) : net_(net) {} |
| 176 | + |
| 177 | + Variable* AddOp(const string& type, const vector<Variable>& inputs, |
| 178 | + size_t size, Activation act) { |
| 179 | + // much code here. |
| 180 | + // ... |
| 181 | + net_->AddOp(def); |
| 182 | + need_rebuild_net_ = true; |
| 183 | + net_->InferShape(); |
| 184 | + // ... |
| 185 | + } |
| 186 | + |
| 187 | + Error BackwardFrom(const Variable& cost); |
| 188 | + |
| 189 | + Error Run(Scope* scope, OpContext* context, bool need_backward = true) { |
| 190 | + // backward. |
| 191 | + if (need_backward) { |
| 192 | + if (need_rebuild_net_) { |
| 193 | + AddBackwardOps(); |
| 194 | + AddOptimizerOps(); |
| 195 | + } |
| 196 | + net_->Run(scope, context); |
| 197 | + return; |
| 198 | + } |
| 199 | + // just forward. |
| 200 | + net_->Run(scope, context, 0, last_forward_op_); |
| 201 | + } |
| 202 | + |
| 203 | + protected: |
| 204 | + Error AddBackwardOps(); |
| 205 | + Error AddOptimizerOps(); |
| 206 | + |
| 207 | + private: |
| 208 | + Net* net_; |
| 209 | + OpIndex last_forward_op_{-1}; |
| 210 | + bool need_rebuild_net_{true}; |
| 211 | +} |
| 212 | +``` |
| 213 | +
|
| 214 | +## Compatibility with RNN |
| 215 | +
|
| 216 | +Benefitting from the decoupling of `PlainNet.Run` and `Scope`, `PlainNet` is compatible with future RNN design, |
| 217 | +for example we can implement a simple recurrent neural network as follows |
| 218 | +
|
| 219 | +```c++ |
| 220 | +// copy some `vars` form `source` to `target` |
| 221 | +void Copy(const Scope &source, Scope &target, |
| 222 | + const std::vector<std::string> &vars); |
| 223 | +
|
| 224 | +Scope default_scope; |
| 225 | +// some initial mutations on `default_scope` here. |
| 226 | +
|
| 227 | +auto rnn_step_net = PlainNet(rnn_step_net_def); |
| 228 | +
|
| 229 | +// Create rnn's states, the last scope is used to store rnn outputs. |
| 230 | +Scope *rnn_states = new Scope[num_states + 1]; |
| 231 | +
|
| 232 | +for (int i = 0; i < num_states + 1; i++) { |
| 233 | + // Initialize all rnn state scopes, copy parameters and so on. |
| 234 | + rnn_states[i].CreateVars(rnn_step_net_def); |
| 235 | + Copy(default_scope, rnn_states[i], rnn_related_vars); |
| 236 | + // Prepare rnn's inlinks, just copy inlink variables to each state. |
| 237 | + Copy(default_scope, rnn_states[i], inlink_vars); |
| 238 | +} |
| 239 | +
|
| 240 | +// Run the rnn. |
| 241 | +for (int i = 0; i < num_states; i++) { |
| 242 | + rnn_step_net.Run(rnn_states[i]); |
| 243 | + // Copy current state's state variables to next state, the related variables |
| 244 | + // are named like "previous_state_xxx". |
| 245 | + Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars) |
| 246 | +} |
| 247 | +
|
| 248 | +// Copy rnn's final outputs to `default_scope`. |
| 249 | +Copy(rnn_states[num_states], default_scope, outlink_vars); |
| 250 | +``` |
0 commit comments