【complex op】No.12、14 add complex support for square & reciprocal (#60821)

* support complex

* fix
zbt78 authored Jan 17, 2024
1 parent 8e568e1 commit a75312d
Showing 7 changed files with 135 additions and 14 deletions.
paddle/phi/common/complex.h (6 changes: 4 additions & 2 deletions)
@@ -294,8 +294,10 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT
thrust::complex<T>(b.real, b.imag));
return a;
#else
-  a.real = a.real * b.real - a.imag * b.imag;
-  a.imag = a.imag * b.real + b.imag * a.real;
+  T r = a.real * b.real - a.imag * b.imag;
+  T i = a.imag * b.real + b.imag * a.real;
+  a.real = r;
+  a.imag = i;
return a;
#endif
}
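
The two deleted lines had a read-after-write hazard: a.real was overwritten before it was read again to compute a.imag, so the non-thrust branch of in-place complex multiplication produced a wrong imaginary part. The new temporaries r and i evaluate both components from the original values first. A minimal plain-Python stand-in (not the Paddle code itself) that reproduces the difference:

# Sketch of the in-place multiply; the "buggy" version mirrors the deleted lines,
# the "fixed" version mirrors the new temporaries r and i.
def mul_inplace_buggy(a, b):              # a, b are [real, imag] pairs
    a[0] = a[0] * b[0] - a[1] * b[1]      # a.real is overwritten here ...
    a[1] = a[1] * b[0] + b[1] * a[0]      # ... but the old a.real is needed here
    return a

def mul_inplace_fixed(a, b):
    r = a[0] * b[0] - a[1] * b[1]
    i = a[1] * b[0] + b[1] * a[0]
    a[0], a[1] = r, i
    return a

print(mul_inplace_buggy([1.0, 2.0], [3.0, 4.0]))  # [-5.0, -14.0]  (wrong)
print(mul_inplace_fixed([1.0, 2.0], [3.0, 4.0]))  # [-5.0, 10.0]   == (1+2j)*(3+4j)
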
paddle/phi/kernels/cpu/activation_grad_kernel.cc (11 changes: 8 additions & 3 deletions)
@@ -304,7 +304,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
+                                                ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
@@ -364,7 +365,9 @@ PD_REGISTER_KERNEL(square_grad,
float,
double,
int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(square_double_grad,
CPU,
ALL_LAYOUT,
@@ -373,7 +376,9 @@ PD_REGISTER_KERNEL(square_double_grad,
double,
phi::dtype::float16,
int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_double_grad,
CPU,
paddle/phi/kernels/cpu/activation_kernel.cc (14 changes: 11 additions & 3 deletions)
@@ -198,7 +198,7 @@ PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -228,8 +228,16 @@ PD_REGISTER_KERNEL(expm1,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
-PD_REGISTER_KERNEL(
-    square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(square,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SquareKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
paddle/phi/kernels/funcs/activation_functor.h (62 changes: 62 additions & 0 deletions)
@@ -323,6 +323,24 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};

+template <typename T>
+struct ReciprocalGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * static_cast<ComplexType<T>>(-1) *
+                   (out * out).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};

// 1st reverse grad
// y = cos(x)
// x --> y
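
The ReciprocalGradFunctor<ComplexType<T>> added in the hunk above encodes the conjugate-gradient rule for a holomorphic op: with out = 1/x, d out/dx = -1/x^2 = -out^2, and the backward pass returns dx = dout * conj(-out^2) = -dout * conj(out^2). The NumPy sketch below checks that rule against a finite difference; the convention grad_x = dL/dRe(x) + 1j * dL/dIm(x) and the surrogate loss are assumptions made for this check, not something stated in the diff.

# Finite-difference check of dx = -dout * conj(out * out) for out = 1 / x.
import numpy as np

x, dout = 0.3 + 0.7j, 0.8 - 0.2j
out = 1.0 / x
dx_rule = -dout * np.conj(out * out)

# Real-valued surrogate loss whose gradient w.r.t. out is exactly dout
# under the assumed convention.
def loss(z):
    return (np.conj(dout) * (1.0 / z)).real

eps = 1e-6
dx_fd = ((loss(x + eps) - loss(x - eps)) / (2 * eps)
         + 1j * (loss(x + 1j * eps) - loss(x - 1j * eps)) / (2 * eps))
print(dx_rule, dx_fd)  # agree to ~1e-6
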
@@ -704,6 +722,22 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

+template <typename T>
+struct SquareGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * static_cast<ComplexType<T>>(2) * x.unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
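
The SquareGradFunctor<ComplexType<T>> above applies the same convention to out = x^2: d out/dx = 2x, so dx = dout * conj(2x) = dout * 2 * conj(x), which is also what the CudaSquareGradFunctor specialization further down computes. The same style of check, again under the assumed gradient convention:

# Finite-difference check of dx = dout * 2 * conj(x) for out = x * x.
import numpy as np

x, dout = 0.3 + 0.7j, 0.8 - 0.2j
dx_rule = dout * 2 * np.conj(x)

def loss(z):
    return (np.conj(dout) * z * z).real

eps = 1e-6
dx_fd = ((loss(x + eps) - loss(x - eps)) / (2 * eps)
         + 1j * (loss(x + 1j * eps) - loss(x - 1j * eps)) / (2 * eps))
print(dx_rule, dx_fd)  # agree to ~1e-6
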
@@ -3220,6 +3254,20 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

+template <typename T>
+struct CudaSquareGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);
+
+  // dx = dout * 2 * conj(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return static_cast<ComplexType<T>>(dout * two * conj(x));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
@@ -3268,6 +3316,20 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};

+template <typename T>
+struct CudaReciprocalGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = -dout * conj(out^2)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return -dout * conj(out * out);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
paddle/phi/kernels/gpu/activation_grad_kernel.cu (11 changes: 8 additions & 3 deletions)
@@ -380,7 +380,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
+                                                ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
@@ -431,7 +432,9 @@ PD_REGISTER_KERNEL(square_grad,
int,
int64_t,
phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(square_double_grad,
GPU,
ALL_LAYOUT,
@@ -441,7 +444,9 @@ PD_REGISTER_KERNEL(square_double_grad,
int,
int64_t,
phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_double_grad,
GPU,
paddle/phi/kernels/gpu/activation_kernel.cu (6 changes: 4 additions & 2 deletions)
@@ -247,7 +247,7 @@ PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -285,7 +285,9 @@ PD_REGISTER_KERNEL(square,
int,
int64_t,
phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
test/legacy_test/test_activation_op.py (39 changes: 38 additions & 1 deletion)
@@ -3416,6 +3416,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
out = np.reciprocal(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3425,12 +3430,29 @@ def setUp(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
-        self.check_grad(['X'], 'Out', max_relative_error=0.01, check_pir=True)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            self.check_grad(
+                ['X'], 'Out', max_relative_error=0.03, check_pir=True
+            )
+        else:
+            self.check_grad(
+                ['X'], 'Out', max_relative_error=0.01, check_pir=True
+            )

def test_check_output(self):
self.check_output(check_pir=True)


+class TestReciprocal_Complex64(TestReciprocal):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestReciprocal_Complex128(TestReciprocal):
+    def init_dtype(self):
+        self.dtype = np.complex128


class TestReciprocal_ZeroDim(TestReciprocal):
def init_shape(self):
self.shape = []
@@ -3801,6 +3823,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
out = np.square(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3816,6 +3843,16 @@ def test_check_output(self):
self.check_output(check_pir=True)


+class TestSquare_Complex64(TestSquare):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestSquare_Complex128(TestSquare):
+    def init_dtype(self):
+        self.dtype = np.complex128


class TestSquare_ZeroDim(TestSquare):
def init_shape(self):
self.shape = []
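
With the CPU and GPU registrations above in place, both ops accept complex tensors from the Python API. A minimal usage sketch (assuming a Paddle build that includes this commit; paddle.to_tensor, paddle.square, and paddle.reciprocal are existing APIs, only their complex support is new here):

import numpy as np
import paddle

x = paddle.to_tensor(np.array([0.3 + 0.7j, -1.0 + 0.5j], dtype=np.complex64))
print(paddle.square(x).numpy())      # elementwise x * x
print(paddle.reciprocal(x).numpy())  # elementwise 1 / x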
