From 8bcc691662d23cdf30f06bb8b67669aa0538f0d5 Mon Sep 17 00:00:00 2001
From: xiaoyewww <641311428@qq.com>
Date: Wed, 20 Sep 2023 07:48:08 +0000
Subject: [PATCH] feat: support complex for softplus

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  7 ++-
 paddle/phi/kernels/cpu/activation_kernel.cc   |  2 +-
 paddle/phi/kernels/funcs/activation_functor.h | 55 ++++++++++++++++++-
 .../phi/kernels/gpu/activation_grad_kernel.cu |  7 ++-
 paddle/phi/kernels/gpu/activation_kernel.cu   |  2 +-
 python/paddle/nn/functional/activation.py    | 14 ++++-
 test/legacy_test/test_activation_op.py       | 18 ++++++
 7 files changed, 94 insertions(+), 11 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index d3cf1cbcb34c19..db5ef65e1c2a35 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -307,7 +307,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
+                                                SoftplusGradKernel)
 
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
                                           ReluDoubleGradKernel)
@@ -320,8 +321,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
                                           SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                           RsqrtDoubleGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
-                                          SoftplusDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
+                                                       SoftplusDoubleGradKernel)
 
 PD_REGISTER_KERNEL(tanh_triple_grad,
                    CPU,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 66480018a52730..31b0d74b941749 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -201,7 +201,7 @@ PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
-PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
 
 PD_REGISTER_KERNEL(exp,
                    CPU,
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 6b77c31d38d4a1..84ca3e0ce88032 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -773,6 +773,31 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct SoftplusGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto x_beta = static_cast<ComplexType<T>>(beta) * x;  // NOLINT
+    dx.device(d) =
+        (x_beta > static_cast<ComplexType<T>>(threshold))
+            .select(dout,
+                    dout / (static_cast<ComplexType<T>>(1) + (-x_beta).exp())
+                               .unaryExpr(Conj<T>()));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
   float beta;
@@ -3576,7 +3601,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
     MPType x = static_cast<MPType>(arg_x);
     MPType b = static_cast<MPType>(beta);
     MPType t = static_cast<MPType>(threshold);
-    MPType x_beta = x * beta;
+    MPType x_beta = x * static_cast<MPType>(beta);
     return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
   }
 };
@@ -3606,6 +3631,34 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaSoftplusGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  using MPType = typename phi::dtype::MPTypeTrait<ComplexType<T>>::Type;
+  MPType one = static_cast<MPType>(1.0f);
+  float beta;
+  float threshold;
+
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType x = static_cast<MPType>(arg_x);
+    MPType b = static_cast<MPType>(beta);
+    MPType t = static_cast<MPType>(threshold);
+    MPType x_beta = x * static_cast<MPType>(beta);
+    return x_beta > t
+               ? dout
+               : static_cast<ComplexType<T>>(dout / conj(one + exp(-x_beta)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index d592dfad0a52db..ccf0456ef90fe7 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -381,9 +381,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
-                                   SoftplusDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
+                                                SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
+                                                SoftplusDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 000428268bbb14..7206832341bda6 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -250,7 +250,7 @@ PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
-PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
 
 PD_REGISTER_KERNEL(exp,
                    GPU,
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index e02a47d7bf8dda..baffc7678189c7 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1269,7 +1269,7 @@ def softplus(x, beta=1, threshold=20, name=None):
             \end{cases}
 
     Parameters:
-        x (Tensor): The input Tensor with data type float32, float64.
+        x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
         beta (float, optional): The value of :math:`\beta` for softplus. Default is 1
         threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -1294,7 +1294,17 @@ def softplus(x, beta=1, threshold=20, name=None):
         return _C_ops.softplus(x, beta, threshold)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'softplus'
+            x,
+            'x',
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
+            'softplus',
         )
         helper = LayerHelper('softplus', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 7da773d16327fe..cc2af9c25cc55e 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3760,6 +3760,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = ref_softplus(x, beta, threshold)
         self.inputs = {'X': x}
         self.attrs = {'beta': beta, "threshold": threshold}
@@ -3774,6 +3779,19 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
+class TestSoftplus_Complex64(TestSoftplus):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.06)
+
+
+class TestSoftplus_Complex128(TestSoftplus):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSoftplus_ZeroDim(TestSoftplus):
     def init_shape(self):
         self.shape = []
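
Note (illustrative, not part of the patch): for complex inputs the backward
functors above divide by the conjugate of (1 + exp(-beta * x)) rather than the
plain value, i.e. dx = dout / conj(1 + exp(-beta * x)) below the threshold,
which is what the `unaryExpr(Conj<T>())` and `conj(...)` calls implement. A
minimal forward check of the new kernels, assuming a Paddle build that
includes this patch (inputs are kept small so the threshold branch of the
kernel is never taken):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    beta, threshold = 1.0, 20.0
    # Complex input built the same way the new setUp branch builds test data.
    x_np = (
        np.random.uniform(-1, 1, [4, 5]) + 1j * np.random.uniform(-1, 1, [4, 5])
    ).astype(np.complex64)

    out = F.softplus(paddle.to_tensor(x_np), beta=beta, threshold=threshold)

    # Reference value from the same formula ref_softplus uses in the test:
    # softplus(x) = log(1 + exp(beta * x)) / beta when beta * x stays below
    # the threshold, which holds for these small inputs.
    ref = np.log(1 + np.exp(beta * x_np)) / beta
    np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)

Gradient coverage for complex64/complex128 is exercised by the new
TestSoftplus_Complex64 / TestSoftplus_Complex128 cases via check_grad.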