
Commit

fix
yangguohao committed Sep 28, 2023
1 parent bcd1b16 commit 07d7182
Showing 3 changed files with 4 additions and 44 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -316,8 +316,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                           LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
-                                                       SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
+                                          SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                           RsqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
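
Note: SqrtDoubleGradKernel differentiates the first-order sqrt backward, which needs only the forward output. A one-line derivation, restating the formula that appears as a comment in CudaSqrtGradFunctor further down (// dx = dout * 0.5 / out), not new math:

\[
y = \sqrt{x} \;\Rightarrow\; \frac{\partial y}{\partial x} = \frac{1}{2\sqrt{x}} = \frac{0.5}{y},
\qquad dx = dout \cdot \frac{0.5}{y}.
\]

Because everything is expressed in terms of y = out, the sqrt functors declare ActBwdOpFwdDeps::kDepOut.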
41 changes: 1 addition & 40 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2628,7 +2628,6 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
-// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
   float factor;
@@ -2746,44 +2745,6 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
-template <typename T>
-struct SqrtGradGradFunctor<ComplexType<T>>
-    : public BaseActivationFunctor<ComplexType<T>> {
-  template <typename Device>
-  void operator()(const Device& dev,
-                  const DenseTensor* Out,
-                  const DenseTensor* dX,
-                  const DenseTensor* ddX,
-                  DenseTensor* dOut,
-                  DenseTensor* ddOut) const {
-    auto* d = dev.eigen_device();
-    auto ddx = EigenVector<ComplexType<T>>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
-    auto out = EigenVector<ComplexType<T>>::Flatten(
-        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
-    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
-    // calculate dy first, so ddy can inplace ddx
-    if (dOut) {
-      auto dx = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
-      auto dout = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
-      dout.device(*d) =
-          dx * ddx *
-          (static_cast<ComplexType<T>>(-1) / out).unaryExpr(Conj<T>());
-    }
-    if (ddOut) {
-      auto ddout = EigenVector<ComplexType<T>>::Flatten(
-          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
-      ddout.device(*d) =
-          ddx * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
-    }
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
-
 template <typename T>
 struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
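
Note: the deleted specialization mirrored the real-valued SqrtGradGradFunctor directly above it, conjugating the 1/y factors via .unaryExpr(Conj<T>()). As a sketch of where its two formulas come from (differentiating the first backward dx = dout * 0.5 / y with respect to each of its inputs):

\[
ddOut = \frac{\partial\,dx}{\partial\,dout}\,ddx = \frac{0.5\,ddx}{y},
\qquad
dOut = \frac{\partial\,dx}{\partial\,y}\,ddx = -\frac{dout}{2y^{2}}\,ddx = -\frac{dx \cdot ddx}{y},
\]

which matches the in-code comment // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx. With the complex registrations also reverted in this commit, only the real-valued path remains registered for sqrt_double_grad.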
@@ -3723,7 +3684,7 @@ struct CudaSqrtGradFunctor<ComplexType<T>>
   // dx = dout * 0.5 / out
   __device__ __forceinline__ ComplexType<T> operator()(
       const ComplexType<T> dout, const ComplexType<T> out) const {
-    return one_half * dout / out;
+    return dout * conj(one_half / out);
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() {
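
Note: the surviving one-line change is the substantive fix. Under the usual conjugate (Wirtinger) convention for complex autograd — an explanatory assumption here, the diff itself does not spell it out — the incoming cotangent is multiplied by the conjugate of the analytic derivative:

\[
dx = dout \cdot \overline{f'(x)}, \quad f(x) = \sqrt{x},\; f'(x) = \frac{0.5}{out}
\;\Rightarrow\; dx = dout \cdot \operatorname{conj}\!\left(\frac{0.5}{out}\right),
\]

hence return dout * conj(one_half / out); — consistent with the Conj<T> usage in the CPU functors above.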
3 changes: 1 addition & 2 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -385,8 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
                                    SoftplusDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_double_grad,
-                                                SqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
 
