From 4a9ed1ac7cf49dc29e6ff4623af46f30feb2a595 Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Fri, 8 Sep 2023 09:41:19 +0000
Subject: [PATCH] No.10 add complex support for exp/expm1

---
 .../phi/kernels/cpu/activation_grad_kernel.cc |  8 +-
 paddle/phi/kernels/cpu/activation_kernel.cc   |  8 +-
 paddle/phi/kernels/funcs/activation_functor.h | 98 +++++++++++++++++++
 .../phi/kernels/gpu/activation_grad_kernel.cu |  8 +-
 paddle/phi/kernels/gpu/activation_kernel.cu   |  8 +-
 python/paddle/tensor/ops.py                   | 15 ++-
 test/legacy_test/test_activation_op.py        | 57 +++++++++++
 7 files changed, 191 insertions(+), 11 deletions(-)

diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index b3203332ec7d1..87acb27e02cc9 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -340,7 +340,9 @@ PD_REGISTER_KERNEL(exp_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(expm1_grad,
                    CPU,
@@ -348,7 +350,9 @@ PD_REGISTER_KERNEL(expm1_grad,
                    phi::Expm1GradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(
     logit_grad, CPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 8a554470dea39..7e66a94d6c6c3 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -211,7 +211,9 @@ PD_REGISTER_KERNEL(exp,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(expm1,
                    CPU,
@@ -221,7 +223,9 @@ PD_REGISTER_KERNEL(expm1,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
 PD_REGISTER_KERNEL(
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 6295ca14aa3ad..56fc3a6ea598b 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1167,6 +1167,33 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct ExpGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * out.unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
+template <typename T>
+struct Expm1 {};
+
+template <typename T>
+struct Expm1<ComplexType<T>> {
+  HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
+    return exp(val) - static_cast<ComplexType<T>>(1);
+  }
+};
+
 // expm1(x) = e^x - 1
 template <typename T>
 struct Expm1Functor : public BaseActivationFunctor<T> {
@@ -1178,6 +1205,15 @@ struct Expm1Functor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct Expm1Functor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr(Expm1<ComplexType<T>>()).eval();
+  }
+};
+
 template <typename T>
 struct Expm1GradFunctor : public BaseActivationFunctor<T> {
   template <typename Device,
@@ -1194,6 +1230,21 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct Expm1GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * out.unaryExpr(Conj<T>()) + dout;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
 // relu(x) = max(x, 0)
 template <typename T>
 struct ReluCPUFunctor : public BaseActivationFunctor<T> {
@@ -2831,6 +2882,16 @@ struct CudaExpFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaExpFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // exp(x) = exp(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> x) const {
+    return static_cast<ComplexType<T>>(exp(x));
+  }
+};
+
 template <typename T>
 struct CudaSeluFunctor : public BaseActivationFunctor<T> {
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
@@ -2907,6 +2968,20 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaExpGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = dout * exp(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return static_cast<ComplexType<T>>(dout * conj(out));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -2947,6 +3022,15 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaExpm1Functor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> x) const {
+    return static_cast<ComplexType<T>>(Expm1<ComplexType<T>>()(x));
+  }
+};
+
 template <typename T>
 struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
   // dx = dout * out
@@ -2959,6 +3043,20 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaExpm1GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = dout * exp(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return static_cast<ComplexType<T>>(dout * conj(out) + dout);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaSinFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index a0695935de1bc..480c11d177415 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -398,7 +398,9 @@ PD_REGISTER_KERNEL(exp_grad,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
@@ -415,7 +417,9 @@ PD_REGISTER_KERNEL(expm1_grad,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(square_grad,
                    GPU,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 061a02f531538..52de7d3218326 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -261,7 +261,9 @@ PD_REGISTER_KERNEL(exp,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(expm1,
                    GPU,
                    ALL_LAYOUT,
@@ -271,7 +273,9 @@ PD_REGISTER_KERNEL(expm1,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(square,
                    GPU,
                    ALL_LAYOUT,
diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py
index e15bf17beb646..1cc8437ce4f1c 100644
--- a/python/paddle/tensor/ops.py
+++ b/python/paddle/tensor/ops.py
@@ -567,7 +567,7 @@ def exp(x, name=None):
         out = e^x
 
     Args:
-        x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, float32, float64 or float16.
+        x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, float16, float32, float64, complex64 or complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -617,7 +617,7 @@ def expm1(x, name=None):
         out = e^x - 1
 
     Args:
-        x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, float32, float64 or float16.
+        x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, float16, float32, float64, complex64 or complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -640,7 +640,16 @@ def expm1(x, name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
+        [
+            'float16',
+            'uint16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+            'complex64',
+            'complex128',
+        ],
         'expm1',
     )
     helper = LayerHelper('expm1', **locals())
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index a7fe2cf3f602f..f40663fb10cfd 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -154,6 +154,48 @@ def init_shape(self):
         self.shape = []
 
 
+class TestExp_Complex64(OpTest):
+    def setUp(self):
+        self.op_type = "exp"
+        self.python_api = paddle.exp
+        self.public_python_api = paddle.exp
+        self.init_dtype()
+        self.init_shape()
+        self.if_enable_cinn()
+        np.random.seed(1024)
+        x = (
+            np.random.uniform(-1, 1, self.shape)
+            + 1j * np.random.uniform(-1, 1, self.shape)
+        ).astype(self.dtype)
+        out = np.exp(x)
+        self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
+        self.outputs = {'Out': out}
+        self.convert_input_output()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.006)
+
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def init_shape(self):
+        self.shape = [10, 12]
+
+    def if_enable_cinn(self):
+        pass
+
+    def convert_input_output(self):
+        pass
+
+
+class TestExp_Complex128(TestExp_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class Test_Exp_Op_Fp16(unittest.TestCase):
     def test_api_fp16(self):
         with paddle.base.framework._static_guard():
@@ -192,6 +234,11 @@ def setUp(self):
 
         np.random.seed(2049)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.expm1(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -205,6 +252,16 @@ def test_check_output(self):
         self.check_output()
 
 
+class TestExpm1_Complex64(TestExpm1):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestExpm1_Complex128(TestExpm1):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestExpm1_ZeroDim(TestExpm1):
     def init_shape(self):
         self.shape = []
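
Usage note (illustrative sketch, not part of the patch): once these kernels are registered, paddle.exp and paddle.expm1 accept complex64/complex128 inputs, and their backward passes use the conjugated forward output (dx = dout * conj(out) for exp, dx = dout * conj(out) + dout for expm1), matching the grad functors above. The snippet below is a minimal dygraph check against NumPy; the shape, seed, and tolerance are arbitrary choices for illustration, not values taken from the PR.

    import numpy as np
    import paddle

    np.random.seed(1024)
    shape = [10, 12]
    x_np = (
        np.random.uniform(-1, 1, shape) + 1j * np.random.uniform(-1, 1, shape)
    ).astype(np.complex64)

    x = paddle.to_tensor(x_np, stop_gradient=False)
    y = paddle.exp(x)
    z = paddle.expm1(x)

    # Forward results should agree with NumPy up to floating-point error.
    np.testing.assert_allclose(y.numpy(), np.exp(x_np), rtol=1e-5)
    np.testing.assert_allclose(z.numpy(), np.expm1(x_np), rtol=1e-5)

    # Backward uses an all-ones output gradient by default, so per the
    # complex ExpGradFunctor the input gradient equals conj(exp(x)).
    y.backward()
    np.testing.assert_allclose(x.grad.numpy(), np.conj(np.exp(x_np)), rtol=1e-5)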