Skip to content

Commit

Permalink
rewrite scalar_pow backward (#5099)
Browse files Browse the repository at this point in the history
* rewrite scalar_pow backward

* refine

* refine

Co-authored-by: oneflow-ci-bot <[email protected]>
  • Loading branch information
Flowingsun007 and oneflow-ci-bot authored Jun 4, 2021
1 parent 1f9db98 commit 77533b9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
72 changes: 72 additions & 0 deletions oneflow/core/autograd/gradient_funcs/scalar_pow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"

namespace oneflow {
namespace one {

// State captured during the forward pass of scalar_pow and consumed by Apply().
struct ScalarPowInterpState : public OpExprInterpState {
bool requires_grad;  // whether the forward input tensor requires a gradient
double exponent;     // the "exponent" attr of the forward scalar_pow op
};

// Gradient function for the "scalar_pow" user op: y = x ^ exponent.
// Backward dispatches to "scalar_pow_grad", which computes dx from (x, dy).
class ScalarPow : public OpExprGradFunction<ScalarPowInterpState> {
 public:
  // Builds the backward op expression once and caches the forward op's
  // attribute defaults. Returns an error if `op` is not a UserOpExpr.
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    // Fix: validate the cast BEFORE dereferencing. The original called
    // fw_op_expr->proto() first, which is a null dereference when the
    // dynamic_cast fails.
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    const std::string& op_name = fw_op_expr->op_name();
    // exponent=1.0 is a placeholder; the real value is supplied per-call
    // through `attrs` in Apply().
    grad_op_ = JUST(op_expr_helper::ScalarPowGradOp(/*exponent=*/1.0, GradientOpName(op_name)));
    return Maybe<void>::Ok();
  }

  // Records whether a gradient is needed, the exponent attr, and the forward
  // input tensor (needed to evaluate d/dx x^p = p * x^(p-1)).
  Maybe<void> Capture(ScalarPowInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    // Skip attr lookup and tensor capture when no gradient will be requested.
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->exponent = JUST(composed_attrs.GetAttr<double>("exponent"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes in_grads[0] = scalar_pow_grad(x, out_grads[0], exponent).
  Maybe<void> Apply(const ScalarPowInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);
    MutableAttrMap attrs;
    // Override the placeholder exponent baked into grad_op_ at Init() time.
    JUST(attrs.SetAttr<double>("exponent", ctx->exponent));
    in_grads->resize(1);
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {x, out_grads.at(0)}, attrs));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;  // cached "scalar_pow_grad" op expression
  AttrMap base_attrs_;               // forward op's default attribute values
};

// Register ScalarPow as the autograd gradient function for the "scalar_pow" op.
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow", ScalarPow);

} // namespace one
} // namespace oneflow
25 changes: 25 additions & 0 deletions oneflow/core/framework/op_expr_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ Maybe<one::UserOpExpr> ReduceSumLikeOp(const std::vector<int32_t>& axis, const s
.Build();
}

// Builds a "scalar_pow" op expression with an auto-generated unique op name.
Maybe<one::UserOpExpr> ScalarPowOp(const double& exponent) {
  const std::string op_name = UniqueOpName("scalar_pow");
  return ScalarPowOp(exponent, op_name);
}

// Builds the forward "scalar_pow" op expression: out = in ^ exponent.
Maybe<one::UserOpExpr> ScalarPowOp(const double& exponent, const std::string& name) {
  one::OpBuilder builder("scalar_pow", name);
  builder.Input("in");
  builder.Attr<double>("exponent", exponent);
  builder.Output("out");
  return builder.Build();
}

// Builds a "scalar_pow_grad" op expression with an auto-generated unique op name.
Maybe<one::UserOpExpr> ScalarPowGradOp(const double& exponent) {
  const std::string op_name = UniqueOpName("scalar_pow_grad");
  return ScalarPowGradOp(exponent, op_name);
}

// Builds the backward "scalar_pow_grad" op expression: dx from (x, dy, exponent).
Maybe<one::UserOpExpr> ScalarPowGradOp(const double& exponent, const std::string& name) {
  one::OpBuilder builder("scalar_pow_grad", name);
  builder.Input("x");
  builder.Input("dy");
  builder.Attr<double>("exponent", exponent);
  builder.Output("dx");
  return builder.Build();
}

template<>
Maybe<one::UserOpExpr> ScalarMulOp(const float& scalar, const std::string& name) {
return one::OpBuilder("scalar_mul", name)
Expand Down
6 changes: 6 additions & 0 deletions oneflow/core/framework/op_expr_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ Maybe<one::UserOpExpr> ReduceSumOp(const std::vector<int32_t>& reduce_axes, cons
Maybe<one::UserOpExpr> ReduceSumLikeOp(const std::vector<int32_t>& axis);
Maybe<one::UserOpExpr> ReduceSumLikeOp(const std::vector<int32_t>& axis, const std::string& name);

Maybe<one::UserOpExpr> ScalarPowOp(const double& exponent);
Maybe<one::UserOpExpr> ScalarPowOp(const double& exponent, const std::string& name);

Maybe<one::UserOpExpr> ScalarPowGradOp(const double& exponent);
Maybe<one::UserOpExpr> ScalarPowGradOp(const double& exponent, const std::string& name);

template<typename T>
Maybe<one::UserOpExpr> ScalarAddOp(const T& scalar);

Expand Down
2 changes: 1 addition & 1 deletion oneflow/python/test/modules/test_math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_pow(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_pow,
# _test_pow_backward TODO:(zhaoluyang) >> rewrite scalar_pow op backward
_test_pow_backward,
]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
Expand Down

0 comments on commit 77533b9

Please sign in to comment.