Skip to content

Commit 0bd49a5

Browse files
authored
move net_op to operators/ (#3201)
* move net_op to operators
1 parent d953611 commit 0bd49a5

File tree

12 files changed

+76
-63
lines changed

12 files changed

+76
-63
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
3131
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
3232
add_dependencies(framework_py_proto framework_py_proto_init)
3333

34-
cc_library(net SRCS net.cc DEPS op_registry)
35-
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
36-
37-
cc_library(backward SRCS backward.cc DEPS net)
34+
cc_library(backward SRCS backward.cc DEPS net_op)
3835
cc_test(backward_test SRCS backward_test.cc DEPS backward)
3936
cc_library(paddle_pybind SHARED
4037
SRCS pybind.cc

paddle/framework/backward.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
#include "paddle/framework/backward.h"
1616
#include <list>
17-
#include "paddle/framework/net.h"
1817
#include "paddle/framework/op_registry.h"
18+
#include "paddle/operators/net_op.h"
1919

2020
namespace paddle {
2121
namespace framework {
@@ -32,7 +32,7 @@ static bool AllInSet(const std::vector<std::string>& names,
3232
}
3333

3434
static std::shared_ptr<OperatorBase> NOP() {
35-
auto net_op = std::make_shared<NetOp>();
35+
auto net_op = std::make_shared<operators::NetOp>();
3636
net_op->type_ = "@NOP@";
3737
net_op->CompleteAddOp();
3838
return net_op;
@@ -77,11 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
7777
}
7878

7979
// Returned gradient network
80-
auto net = std::make_shared<NetOp>();
80+
auto net = std::make_shared<operators::NetOp>();
8181

8282
if (forwardOp.IsNetOp()) {
8383
// Because forwardOp is a net op, it can static_cast.
84-
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
84+
auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
8585

8686
// Map from output gradient variable name to operator's indices in backward
8787
// net. That operator generates that variable.

paddle/framework/backward_test.cc

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
#include "paddle/framework/backward.h"
1616

1717
#include <gtest/gtest.h>
18-
#include "paddle/framework/net.h"
1918
#include "paddle/framework/op_registry.h"
19+
#include "paddle/operators/net_op.h"
20+
#include "paddle/operators/type_alias.h"
2021

2122
namespace paddle {
2223
namespace framework {
@@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
7071
}
7172
};
7273

73-
class FcOp : public NetOp {
74+
class FcOp : public ops::NetOp {
7475
public:
7576
void Init() override {
7677
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
@@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) {
182183
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
183184
ASSERT_NE(no_input_gop, nullptr);
184185
ASSERT_TRUE(no_input_gop->IsNetOp());
185-
ASSERT_EQ(0UL, std::static_pointer_cast<f::NetOp>(no_input_gop)->ops_.size());
186+
ASSERT_EQ(0UL,
187+
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
186188
}
187189

188190
TEST(Backward, net_fc_backward_normal) {
@@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) {
191193
ASSERT_NE(fwd, nullptr);
192194
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
193195
ASSERT_TRUE(gop->IsNetOp());
194-
auto net = static_cast<f::NetOp *>(gop.get());
196+
auto net = static_cast<ops::NetOp *>(gop.get());
195197

196198
ASSERT_NO_THROW(net->DebugString());
197199

@@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
214216
ASSERT_NE(fwd, nullptr);
215217
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
216218
ASSERT_TRUE(gop->IsNetOp());
217-
auto net = static_cast<f::NetOp *>(gop.get());
219+
auto net = static_cast<ops::NetOp *>(gop.get());
218220

219221
ASSERT_NO_THROW(net->DebugString());
220222

@@ -228,15 +230,15 @@ TEST(Backward, net_fc_backward_not_have_b) {
228230
}
229231

230232
TEST(Backward, net_input_of_network_not_need_grad) {
231-
f::NetOp net;
233+
ops::NetOp net;
232234
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
233235
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
234236
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
235237
{"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
236238
net.CompleteAddOp();
237239
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
238240
ASSERT_TRUE(bwd->IsNetOp());
239-
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
241+
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
240242

241243
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
242244
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
@@ -253,22 +255,22 @@ TEST(Backward, net_input_of_network_not_need_grad) {
253255

254256
ASSERT_EQ(2UL, bwd_net->ops_.size());
255257
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
256-
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
258+
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
257259
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
258260
ASSERT_EQ(
259261
f::OperatorBase::EMPTY_VAR_NAME(),
260262
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
261263
}
262264

263265
TEST(Backward, net_shared_weight) {
264-
f::NetOp net;
266+
ops::NetOp net;
265267
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
266268
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
267269
net.CompleteAddOp();
268270

269271
auto bwd = f::Backward(net, {});
270272
ASSERT_TRUE(bwd->IsNetOp());
271-
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
273+
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
272274
ASSERT_EQ(3UL, bwd_net->ops_.size());
273275
ASSERT_EQ("add", bwd_net->ops_[2]->type_);
274276
}
@@ -285,23 +287,23 @@ TEST(Backward, op_all_input_are_not_need) {
285287
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
286288
auto backward = f::Backward(*fwd, {"X", "b"});
287289
ASSERT_TRUE(backward->IsNetOp());
288-
auto net = static_cast<f::NetOp *>(backward.get());
290+
auto net = static_cast<ops::NetOp *>(backward.get());
289291
ASSERT_TRUE(net->ops_.empty());
290292
}
291293

292294
TEST(Backward, op_all_output_are_not_need) {
293295
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
294296
auto backward = f::Backward(*fwd, {"Out"});
295297
ASSERT_TRUE(backward->IsNetOp());
296-
auto net = static_cast<f::NetOp *>(backward.get());
298+
auto net = static_cast<ops::NetOp *>(backward.get());
297299
ASSERT_TRUE(net->ops_.empty());
298300
}
299301

300302
TEST(Backward, op_part_of_output_are_not_need) {
301303
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
302304
auto backward = f::Backward(*fwd, {"Z"});
303305
ASSERT_TRUE(backward->IsNetOp());
304-
auto net = static_cast<f::NetOp *>(backward.get());
306+
auto net = static_cast<ops::NetOp *>(backward.get());
305307
ASSERT_EQ(net->ops_.size(), 2UL);
306308

307309
auto &fill_zero = *net->ops_[0];
@@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) {
341343
}
342344

343345
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
344-
f::NetOp net;
346+
ops::NetOp net;
345347
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
346348
{"mul_out1", "add_out1", "out1"}, {}));
347349
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
@@ -351,7 +353,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
351353
net.CompleteAddOp();
352354
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
353355
ASSERT_TRUE(backward->IsNetOp());
354-
auto bwd_net = static_cast<f::NetOp *>(backward.get());
356+
auto bwd_net = static_cast<ops::NetOp *>(backward.get());
355357
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
356358
auto &grad_fc = *bwd_net->ops_[0];
357359
EXPECT_EQ(grad_fc.inputs_.size(),

paddle/framework/pybind.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ limitations under the License. */
1717
#include <vector>
1818

1919
#include "paddle/framework/backward.h"
20-
#include "paddle/framework/net.h"
2120
#include "paddle/framework/op_registry.h"
2221
#include "paddle/framework/operator.h"
2322
#include "paddle/framework/scope.h"
2423
#include "paddle/framework/tensor_py.h"
24+
#include "paddle/operators/net_op.h"
25+
#include "paddle/operators/type_alias.h"
2526
#include "paddle/platform/enforce.h"
2627
#include "paddle/platform/place.h"
2728
#include "pybind11/numpy.h"
@@ -118,7 +119,9 @@ All parameter, weight, gradient are variables in Paddle.
118119
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
119120
py::return_value_policy::reference)
120121
.def("get_net",
121-
[](Variable &self) -> NetOp * { return self.GetMutable<NetOp>(); },
122+
[](Variable &self) -> ops::NetOp * {
123+
return self.GetMutable<ops::NetOp>();
124+
},
122125
py::return_value_policy::reference);
123126

124127
py::class_<Scope>(m, "Scope", "")
@@ -196,22 +199,24 @@ All parameter, weight, gradient are variables in Paddle.
196199

197200
ExposeOperator(operator_base);
198201

199-
py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
202+
py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
200203

201204
net.def_static("create",
202-
[]() -> std::shared_ptr<NetOp> {
203-
auto retv = std::make_shared<NetOp>();
205+
[]() -> std::shared_ptr<ops::NetOp> {
206+
auto retv = std::make_shared<ops::NetOp>();
204207
retv->type_ = "plain_net";
205208
return retv;
206209
})
207-
.def("add_op", &NetOp::AddOp)
208-
.def("add_op",
209-
[](NetOp &self, const std::shared_ptr<NetOp> &net) -> void {
210-
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
211-
})
212-
.def("complete_add_op", &NetOp::CompleteAddOp)
210+
.def("add_op", &ops::NetOp::AddOp)
211+
.def(
212+
"add_op",
213+
[](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
214+
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
215+
})
216+
.def("complete_add_op", &ops::NetOp::CompleteAddOp)
213217
.def("complete_add_op",
214-
[](std::shared_ptr<NetOp> &self) { self->CompleteAddOp(); });
218+
[](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
219+
215220
ExposeOperator(net);
216221

217222
m.def("unique_integer", UniqueIntegerGenerator);

paddle/operators/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ function(op_library TARGET)
4141
endif()
4242
endfunction()
4343

44+
cc_library(net_op SRCS net_op.cc DEPS op_registry)
45+
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
46+
4447
op_library(add_op SRCS add_op.cc add_op.cu)
4548
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
4649

@@ -59,6 +62,6 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
5962

6063
op_library(fc_op
6164
SRCS fc_op.cc
62-
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
63-
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
65+
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op)
66+
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net_op)
6467
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)

paddle/framework/net.cc renamed to paddle/operators/net_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
limitations under the License.
1515
*/
1616

17-
#include "paddle/framework/net.h"
17+
#include "paddle/operators/net_op.h"
1818
#include "paddle/framework/op_registry.h"
1919

2020
namespace paddle {
21-
namespace framework {
21+
namespace operators {
2222

2323
void NetOp::CompleteAddOp(bool calc) {
2424
add_op_done_ = true;
@@ -74,5 +74,5 @@ std::string NetOp::DebugString() const {
7474

7575
bool NetOp::IsNetOp() const { return true; }
7676

77-
} // namespace framework
77+
} // namespace operators
7878
} // namespace paddle

paddle/framework/net.h renamed to paddle/operators/net_op.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include <paddle/framework/op_desc.pb.h>
18-
#include <paddle/framework/operator.h>
17+
#include "paddle/framework/op_desc.pb.h"
1918
#include "paddle/framework/op_proto.pb.h"
2019
#include "paddle/framework/op_registry.h"
20+
#include "paddle/framework/operator.h"
2121
#include "paddle/framework/scope.h"
22+
#include "paddle/operators/type_alias.h"
2223
#include "paddle/platform/device_context.h"
2324

2425
namespace paddle {
25-
namespace framework {
26+
namespace operators {
27+
2628
/**
2729
* @brief Network is also a type of Operator
2830
*
@@ -37,13 +39,13 @@ namespace framework {
3739
* This is the base class of network, all the networks should implement the APIs
3840
* it defines.
3941
*/
40-
class NetOp : public OperatorBase {
41-
public:
42+
class NetOp : public framework::OperatorBase {
43+
public:
4244
/**
4345
* Infer all the operators' input and output variables' shapes, will be called
4446
* before every mini-batch
4547
*/
46-
void InferShape(const Scope& scope) const override {
48+
void InferShape(const framework::Scope& scope) const override {
4749
for (auto& op : ops_) {
4850
op->InferShape(scope);
4951
}
@@ -56,7 +58,7 @@ class NetOp : public OperatorBase {
5658
* scope will be used instead. If no OpContext is provicded, default context
5759
* will be used.
5860
*/
59-
void Run(const Scope& scope,
61+
void Run(const framework::Scope& scope,
6062
const platform::DeviceContext& dev_ctx) const override {
6163
for (auto& op : ops_) {
6264
op->Run(scope, dev_ctx);
@@ -88,7 +90,7 @@ class NetOp : public OperatorBase {
8890

8991
std::vector<std::shared_ptr<OperatorBase>> ops_;
9092

91-
private:
93+
private:
9294
bool add_op_done_{false};
9395

9496
template <typename T, typename KeyType>
@@ -97,5 +99,5 @@ class NetOp : public OperatorBase {
9799
}
98100
};
99101

100-
} // namespace framework
102+
} // namespace operators
101103
} // namespace paddle
File renamed without changes.

paddle/framework/net_op_test.cc renamed to paddle/operators/net_op_test.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
#include "paddle/operators/net_op.h"
2+
13
#include <gtest/gtest.h>
2-
#include <paddle/framework/net.h>
3-
#include <paddle/framework/op_registry.h>
4-
#include <paddle/framework/operator.h>
4+
5+
#include "paddle/framework/op_registry.h"
6+
#include "paddle/framework/operator.h"
57

68
namespace paddle {
7-
namespace framework {
9+
namespace operators {
810

911
static int infer_shape_cnt = 0;
1012
static int run_cnt = 0;
1113

1214
class TestOp : public OperatorBase {
13-
public:
15+
public:
1416
void InferShape(const framework::Scope& scope) const override {
1517
++infer_shape_cnt;
1618
}
@@ -21,7 +23,7 @@ class TestOp : public OperatorBase {
2123
};
2224

2325
class EmptyOp : public OperatorBase {
24-
public:
26+
public:
2527
void InferShape(const Scope& scope) const override {}
2628
void Run(const Scope& scope,
2729
const platform::DeviceContext& dev_ctx) const override {}
@@ -73,7 +75,7 @@ TEST(OpKernel, all) {
7375
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
7476
}
7577

76-
TEST(Net, insert_op) {
78+
TEST(NetOp, insert_op) {
7779
NetOp net;
7880
auto op1 = std::make_shared<EmptyOp>();
7981
op1->inputs_ = {"x", "w1", "b1"};
@@ -85,5 +87,5 @@ TEST(Net, insert_op) {
8587
ASSERT_EQ(3UL, net.ops_.size());
8688
}
8789

88-
} // namespace framework
90+
} // namespace operators
8991
} // namespace paddle

0 commit comments

Comments
 (0)