@@ -1,16 +1,17 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
 
-http://www.apache.org/licenses/LICENSE-2.0
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
 
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
 #include "paddle/operators/softmax_op.h"
 
 namespace paddle {
@@ -19,12 +20,13 @@ namespace operators {
 class SoftmaxOp : public OperatorWithKernel {
  protected:
   void InferShape(const InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
-    PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
+    PADDLE_ENFORCE(ctx.InputSize() == 1UL,
+                   "Only one input is need for softmax");
+    PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
                    "The input of softmax op must be matrix");
-    PADDLE_ENFORCE(ctx.OutputSize() == 1,
+    PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
                    "Only one output is need for softmax");
-    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
+    ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
   }
 };
 
@@ -40,16 +42,27 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
 
 class SoftmaxOpGrad : public OperatorWithKernel {
  protected:
-  void InferShape(const InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "SoftmaxOpGrad";
-    return "";
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 3UL,
+                   "Input of SoftmaxOpGrad should be 3, X, Y, YG");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
+                   "Output of SoftmaxOpGrad should be 1");
+    PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
+    PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr,
+                   "Input(Y@GRAD) should not be null");
+    PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
+                       ctx.Input<Tensor>(GRAD_VAR_NAME("Y"))->dims(),
+                   "the shape of Input(0) and Input(1) should be the same");
+    ctx.Output<Tensor>(GRAD_VAR_NAME("X"))
+        ->Resize(ctx.Input<Tensor>("Y")->dims());
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
+REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
+REGISTER_OP_CPU_KERNEL(softmax_grad,
+                       ops::SoftmaxGradKernel<ops::CPUPlace, float>);
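
Note: the diff registers a new CPU kernel, ops::SoftmaxGradKernel, but its definition lives in paddle/operators/softmax_op.h and is not part of this change. For orientation only, the backward rule such a kernel has to implement is dX = (dY - rowwise_sum(dY * Y)) * Y, applied row by row to the forward softmax output Y. Below is a minimal, standalone C++ sketch of that math using plain loops; the function name, buffer layout, and types are illustrative assumptions and do not reflect the actual Paddle kernel, which works on the framework's Tensor/Eigen types.

// Standalone sketch (not Paddle code): softmax backward rule
//   dX[i][j] = (dY[i][j] - sum_k dY[i][k] * Y[i][k]) * Y[i][j]
// y, dy, dx are row-major [rows x cols] buffers; y holds the forward
// softmax output, dy the incoming gradient, dx receives the result.
#include <cstddef>
#include <vector>

void SoftmaxGradReference(const std::vector<float>& y,
                          const std::vector<float>& dy,
                          std::vector<float>* dx,
                          std::size_t rows, std::size_t cols) {
  dx->assign(rows * cols, 0.0f);
  for (std::size_t i = 0; i < rows; ++i) {
    // Per-row dot product of the incoming gradient with the softmax output.
    float dot = 0.0f;
    for (std::size_t j = 0; j < cols; ++j) {
      dot += dy[i * cols + j] * y[i * cols + j];
    }
    // Center the gradient by that dot product, then scale by the output.
    for (std::size_t j = 0; j < cols; ++j) {
      (*dx)[i * cols + j] = (dy[i * cols + j] - dot) * y[i * cols + j];
    }
  }
}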