diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index d3cf1cbcb34c1..8e274ece726a1 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -303,7 +303,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 66480018a5273..baf5fcbd1edd3 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -197,7 +197,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 6b77c31d38d4a..8031e480f6c98 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -599,6 +599,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct STanhGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float scale_a;
+  float scale_b;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto a = static_cast<ComplexType<T>>(scale_a);  // NOLINT
+    auto b = static_cast<ComplexType<T>>(scale_b);
+    auto temp = (a * x).tanh() * (a * x).tanh();
+    dx.device(d) =
+        dout *
+        (a * b * (static_cast<ComplexType<T>>(1) - temp)).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct Tangent {
   HOSTDEVICE T operator()(const T& val) const { return tan(val); }
@@ -3560,6 +3586,32 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaSTanhGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
+  float scale_a;
+  float scale_b;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
+    ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
+    ComplexType<T> a = static_cast<ComplexType<T>>(scale_a);
+    ComplexType<T> b = static_cast<ComplexType<T>>(scale_b);
+    ComplexType<T> temp = tanh(a * x);
+    return static_cast<ComplexType<T>>(dout *
+                                       conj(a * b * (one - temp * temp)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index d592dfad0a52d..f7f8b9ad49687 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -379,7 +379,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                     ThresholdedReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 000428268bbb1..d370c6125f095 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -246,7 +246,7 @@ PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 8b74a2900502b..a72b9b81ab43f 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3634,7 +3634,13 @@ def setUp(self):
         scale_b = self.get_scale_b()
 
         np.random.seed(1024)
-        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
         # The same reason with TestAbs
         out = ref_stanh(x, scale_a, scale_b)
 
@@ -3664,6 +3670,16 @@ def init_shape(self):
         self.shape = []
 
 
+class TestSTanhComplex64(TestSTanh):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestSTanhComplex128(TestSTanh):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSTanhAPI(unittest.TestCase):
     # test paddle.nn.stanh
     def get_scale_a(self):
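
For reviewers, a minimal NumPy sketch (not part of the diff above) of the rule the new complex functors implement: the forward is stanh(x) = scale_b * tanh(scale_a * x), and the backward returns dout * conj(scale_a * scale_b * (1 - tanh(scale_a * x)^2)), i.e. the analytic derivative is conjugated for complex inputs, matching the convention used by the other complex activation gradients registered above. The helper names and the default scale values below are illustrative assumptions, not taken from this PR.

import numpy as np


def ref_stanh_complex(x, scale_a=0.67, scale_b=1.7159):
    # Forward: scale_b * tanh(scale_a * x); np.tanh accepts complex dtypes.
    return scale_b * np.tanh(scale_a * x)


def ref_stanh_grad_complex(x, dout, scale_a=0.67, scale_b=1.7159):
    # Mirrors the STanhGradFunctor<ComplexType<T>> / CudaSTanhGradFunctor<ComplexType<T>>
    # rule: dx = dout * conj(a * b * (1 - tanh(a * x)^2)).
    t = np.tanh(scale_a * x)
    return dout * np.conj(scale_a * scale_b * (1.0 - t * t))


if __name__ == "__main__":
    rng = np.random.default_rng(1024)
    x = (rng.uniform(0.1, 1, 4) + 1j * rng.uniform(0.1, 1, 4)).astype(np.complex64)
    print(ref_stanh_complex(x))
    print(ref_stanh_grad_complex(x, np.ones_like(x)))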