From 77533b919d098c0a0701caf947105d0dd846268c Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 4 Jun 2021 19:16:04 +0800 Subject: [PATCH] rewrite scalar_pow backward (#5099) * rewrite scalar_pow backward * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../autograd/gradient_funcs/scalar_pow.cpp | 72 +++++++++++++++++++ oneflow/core/framework/op_expr_helper.cpp | 25 +++++++ oneflow/core/framework/op_expr_helper.h | 6 ++ oneflow/python/test/modules/test_math_ops.py | 2 +- 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 oneflow/core/autograd/gradient_funcs/scalar_pow.cpp diff --git a/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp b/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp new file mode 100644 index 00000000000..03f021aaad5 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp @@ -0,0 +1,72 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_expr_helper.h" + +namespace oneflow { +namespace one { + +struct ScalarPowInterpState : public OpExprInterpState { + bool requires_grad; + double exponent; +}; + +class ScalarPow : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + const std::string& op_name = fw_op_expr->op_name(); + grad_op_ = JUST(op_expr_helper::ScalarPowGradOp(/*exponent=*/1.0, GradientOpName(op_name))); + return Maybe::Ok(); + } + + Maybe Capture(ScalarPowInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->exponent = JUST(composed_attrs.GetAttr("exponent")); + ctx->SaveTensorForBackward(inputs.at(0)); + return Maybe::Ok(); + } + + Maybe Apply(const ScalarPowInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + const auto& x = ctx->SavedTensors().at(0); + MutableAttrMap attrs; + JUST(attrs.SetAttr("exponent", ctx->exponent)); + in_grads->resize(1); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {x, out_grads.at(0)}, attrs)); + return Maybe::Ok(); + } + + private: + std::shared_ptr grad_op_; + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow", ScalarPow); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index 0fe244f79f5..9681f30b308 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -168,6 +168,31 @@ Maybe ReduceSumLikeOp(const std::vector& axis, const s .Build(); } +Maybe ScalarPowOp(const double& exponent) { + return ScalarPowOp(exponent, UniqueOpName("scalar_pow")); +} + +Maybe ScalarPowOp(const double& exponent, const std::string& name) { + return one::OpBuilder("scalar_pow", name) + .Input("in") + .Attr("exponent", exponent) + .Output("out") + .Build(); +} + +Maybe ScalarPowGradOp(const double& exponent) { + return ScalarPowGradOp(exponent, UniqueOpName("scalar_pow_grad")); +} + +Maybe ScalarPowGradOp(const double& exponent, const std::string& name) { + return one::OpBuilder("scalar_pow_grad", name) + .Input("x") + .Input("dy") + .Attr("exponent", exponent) + .Output("dx") + .Build(); +} + template<> Maybe ScalarMulOp(const float& scalar, const std::string& name) { return one::OpBuilder("scalar_mul", name) diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h index 82554c8c37e..df79362f016 100644 --- a/oneflow/core/framework/op_expr_helper.h +++ b/oneflow/core/framework/op_expr_helper.h @@ -64,6 +64,12 @@ Maybe ReduceSumOp(const std::vector& reduce_axes, cons Maybe ReduceSumLikeOp(const std::vector& axis); Maybe ReduceSumLikeOp(const std::vector& axis, const std::string& name); +Maybe ScalarPowOp(const double& exponent); +Maybe ScalarPowOp(const double& exponent, const std::string& name); + +Maybe ScalarPowGradOp(const double& exponent); +Maybe ScalarPowGradOp(const double& exponent, const std::string& name); + template Maybe ScalarAddOp(const T& scalar); diff --git a/oneflow/python/test/modules/test_math_ops.py b/oneflow/python/test/modules/test_math_ops.py index 295ba01fe08..3d2b8086a8f 100644 --- a/oneflow/python/test/modules/test_math_ops.py +++ b/oneflow/python/test/modules/test_math_ops.py @@ -391,7 +391,7 @@ def test_pow(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_pow, - # _test_pow_backward TODO:(zhaoluyang) >> rewrite scalar_pow op backward + _test_pow_backward, ] arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] arg_dict["device"] = ["cpu", "cuda"]