From 41b41b52d21c2cfab2b92869ab7c5da362ab7cbc Mon Sep 17 00:00:00 2001
From: yangguohao <70266361+yangguohao@users.noreply.github.com>
Date: Mon, 25 Sep 2023 11:59:59 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Complex=20OP=E3=80=91No.30=20complex?=
 =?UTF-8?q?=20stanh=20op=20(#57639)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  2 +-
 paddle/phi/kernels/cpu/activation_kernel.cc   |  2 +-
 paddle/phi/kernels/funcs/activation_functor.h | 52 +++++++++++++++++++
 .../phi/kernels/gpu/activation_grad_kernel.cu |  2 +-
 paddle/phi/kernels/gpu/activation_kernel.cu   |  2 +-
 test/legacy_test/test_activation_op.py        | 18 ++++++-
 6 files changed, 73 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index be44d968548f4..ee3e2b6b39e8b 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -303,7 +303,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index a2daac870c63e..9bee7c9f11365 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -197,7 +197,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index e5aaf1aeb8d34..f186d35428350 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -599,6 +599,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct STanhGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float scale_a;
+  float scale_b;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto a = static_cast<ComplexType<T>>(scale_a);  // NOLINT
+    auto b = static_cast<ComplexType<T>>(scale_b);
+    auto temp = (a * x).tanh() * (a * x).tanh();
+    dx.device(d) =
+        dout *
+        (a * b * (static_cast<ComplexType<T>>(1) - temp)).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct Tangent {
   HOSTDEVICE T operator()(const T& val) const { return tan(val); }
@@ -3578,6 +3604,32 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaSTanhGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
+  float scale_a;
+  float scale_b;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
+  }
+
+  // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
+    ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
+    ComplexType<T> a = static_cast<ComplexType<T>>(scale_a);
+    ComplexType<T> b = static_cast<ComplexType<T>>(scale_b);
+    ComplexType<T> temp = tanh(a * x);
+    return static_cast<ComplexType<T>>(dout *
+                                       conj(a * b * (one - temp * temp)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 36d7d3ae1baf8..ff1552370a55c 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -379,7 +379,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                    ThresholdedReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index a506415d36bab..a14f32599552a 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -246,7 +246,7 @@ PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 915a10fdc180b..caf2390d3fec7 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3694,7 +3694,13 @@ def setUp(self):
         scale_b = self.get_scale_b()
 
         np.random.seed(1024)
-        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
         # The same reason with TestAbs
         out = ref_stanh(x, scale_a, scale_b)
 
@@ -3724,6 +3730,16 @@ def init_shape(self):
         self.shape = []
 
 
+class TestSTanhComplex64(TestSTanh):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestSTanhComplex128(TestSTanh):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSTanhAPI(unittest.TestCase):
     # test paddle.nn.stanh
     def get_scale_a(self):
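
Note (not part of the patch): both new backward functors compute dx = dout * conj(a * b * (1 - tanh(a*x)^2)), i.e. the conjugate of the derivative of b * tanh(a*x), which follows the same convention as the other complex activation gradients already registered WITH_COMPLEX in this file (e.g. silu_grad). The NumPy sketch below mirrors that convention against the test's ref_stanh forward; ref_stanh_grad and the default scale values 0.67 / 1.7159 are illustrative assumptions added here, not code from this patch.

import numpy as np

def ref_stanh(x, scale_a=0.67, scale_b=1.7159):
    # Forward reference used by test_activation_op.py: b * tanh(a * x).
    return scale_b * np.tanh(scale_a * x)

def ref_stanh_grad(x, dout, scale_a=0.67, scale_b=1.7159):
    # Mirrors the complex backward above: dx = dout * conj(a * b * (1 - tanh(a*x)^2)).
    t = np.tanh(scale_a * x)
    return dout * np.conj(scale_a * scale_b * (1.0 - t * t))

if __name__ == "__main__":
    rng = np.random.default_rng(1024)
    x = (rng.uniform(0.1, 1, (4, 5)) + 1j * rng.uniform(0.1, 1, (4, 5))).astype(
        np.complex64
    )
    dout = np.ones_like(x)
    print(ref_stanh(x)[0, 0], ref_stanh_grad(x, dout)[0, 0])

With the patch applied, paddle.stanh accepts complex64/complex128 tensors on CPU and GPU, and its gradient should line up with ref_stanh_grad under the same conjugate-derivative convention.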