From 809466c01681aefe471f0d6e798f77f4ff76492c Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Tue, 12 Sep 2023 11:43:48 +0000
Subject: [PATCH 1/2] No.33 add complex support for sigmoid

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  8 +++--
 paddle/phi/kernels/cpu/activation_kernel.cc   |  2 +-
 paddle/phi/kernels/funcs/activation_functor.h | 35 +++++++++++++++++++
 .../phi/kernels/gpu/activation_grad_kernel.cu |  8 +++--
 paddle/phi/kernels/gpu/activation_kernel.cu   |  2 +-
 python/paddle/tensor/ops.py                   | 14 ++++++--
 test/legacy_test/test_activation_op.py        | 28 +++++++++++++++
 7 files changed, 87 insertions(+), 10 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 87acb27e02cc9..3dbc281dde599 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -423,9 +423,11 @@ PD_REGISTER_KERNEL(cos_triple_grad,
                    phi::dtype::complex<double>) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_grad, SigmoidGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_double_grad,
+                                                SigmoidDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
+                                                SigmoidTripleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
                                                 LogSigmoidGradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 7e66a94d6c6c3..11f3bee3ba0cb 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -231,7 +231,7 @@ PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
 PD_REGISTER_KERNEL(
     square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
 PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 56fc3a6ea598b..d451a1d2acb5a 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1969,6 +1969,24 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct SigmoidGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    ComplexType<T> one = static_cast<ComplexType<T>>(1);
+    dx.device(d) = dout * (out * (one - out)).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 /*
     Out
     DOut -> SigmoidGradGrad -> DOutNew
@@ -3981,6 +3999,23 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaSigmoidGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  using Complex = ComplexType<T>;
+  Complex one = Complex(1.0f);
+  // dx = dout * conj(out * (1 - out))
+  __device__ __forceinline__ Complex operator()(const Complex dout,
+                                                const Complex out) const {
+    Complex y = out * (one - out);
+    return dout * conj(y);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 480c11d177415..4cac923c29cbb 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -495,9 +495,11 @@ PD_REGISTER_KERNEL(cos_triple_grad,
                    phi::dtype::complex<double>) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_grad, SigmoidGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_double_grad,
+                                                SigmoidDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
+                                                SigmoidTripleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
                                                 LogSigmoidGradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 52de7d3218326..5c369dc39cc3c 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -293,7 +293,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py
index 1cc8437ce4f1c..bf0f4d4beb1e5 100644
--- a/python/paddle/tensor/ops.py
+++ b/python/paddle/tensor/ops.py
@@ -828,7 +828,7 @@ def sigmoid(x, name=None):
         out = \\frac{1}{1 + e^{-x}}
 
     Args:
-        x (Tensor): Input of Sigmoid operator, an N-D Tensor, with data type float32, float64 or float16.
+        x (Tensor): Input of Sigmoid operator, an N-D Tensor, with data type float16, float32, float64, complex64 or complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -850,7 +850,17 @@ def sigmoid(x, name=None):
         return _C_ops.sigmoid(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'sigmoid'
+            x,
+            'x',
+            [
+                'float16',
+                'float32',
+                'float64',
+                'uint16',
+                'complex64',
+                'complex128',
+            ],
+            'sigmoid',
         )
         helper = LayerHelper('sigmoid', **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index f40663fb10cfd..5fcee93934414 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -375,6 +375,34 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out', max_relative_error=0.01, check_prim=True)
 
 
+class TestSigmoid_Complex64(TestSigmoid):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_check_output(self):
+        self.check_output(check_prim=False)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            max_relative_error=0.007,
+            check_prim=False,
+        )
+
+
+class TestSigmoid_Complex128(TestSigmoid_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_prim=False,
+        )
+
+
 class TestSigmoid_ZeroDim(TestSigmoid):
     def init_shape(self):
         self.shape = []

From 63609d09739ef4ecdd9918f695ed635375581736 Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Wed, 13 Sep 2023 07:25:17 +0000
Subject: [PATCH 2/2] fix: fix testcase

---
 test/legacy_test/test_activation_op.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 5fcee93934414..63c0ca8094ce6 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -356,6 +356,11 @@ def setUp(self):
         self.if_enable_cinn()
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = 1 / (1 + np.exp(-x))
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -386,7 +391,7 @@ def test_check_grad(self):
         self.check_grad(
             ['X'],
             'Out',
-            max_relative_error=0.007,
+            max_relative_error=0.006,
             check_prim=False,
         )
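
Note on the gradient formula used above (not part of the patches; a sketch of the convention they appear to follow): for real inputs the sigmoid backward pass is dx = dout * out * (1 - out), while the complex specializations multiply dout by the conjugate of the local derivative, i.e. the conjugate (Wirtinger-style) gradient convention:

    \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
    \sigma'(z) = \sigma(z)\bigl(1 - \sigma(z)\bigr), \qquad
    dx = dout \cdot \overline{\sigma(z)\bigl(1 - \sigma(z)\bigr)}

This is what unaryExpr(Conj<T>()) computes on the CPU path and conj(y) on the CUDA path.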
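
A minimal usage sketch of the new dtype support, assuming a Paddle build with both patches applied; the shape, values, and tolerance are illustrative only and not taken from the patches:

    import numpy as np
    import paddle

    # Any complex64/complex128 tensor should be accepted by sigmoid after this change.
    x_np = (
        np.random.uniform(-1, 1, [3, 4])
        + 1j * np.random.uniform(-1, 1, [3, 4])
    ).astype(np.complex64)

    x = paddle.to_tensor(x_np)
    y = paddle.nn.functional.sigmoid(x)  # elementwise 1 / (1 + exp(-x)), stays complex64

    np.testing.assert_allclose(y.numpy(), 1 / (1 + np.exp(-x_np)), rtol=1e-5)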