【complex op】No.12、14 add complex support for square & reciprocal #60821

Merged · 2 commits · Jan 17, 2024
6 changes: 4 additions & 2 deletions paddle/phi/common/complex.h
@@ -294,8 +294,10 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT
thrust::complex<T>(b.real, b.imag));
return a;
#else
a.real = a.real * b.real - a.imag * b.imag;
a.imag = a.imag * b.real + b.imag * a.real;
T r = a.real * b.real - a.imag * b.imag;
Contributor:

The previous code does look problematic. Could you put together some test cases for this situation and check whether *= produces the correct result? If it does not, please show it.

Contributor Author:

I found this while reading the cumprod code: during backward propagation on the CPU a complex *= is performed and yields a wrong result. For example, in the attached screenshot the gradient x_grad[0][0] should be conj(1 + x[1][0]) = conj(1 + 2 + 3j) = 3 - 3j, but with *= the computation becomes conj(1 + 1 *= x[1][0]) = conj(1 + [(1*2 - 0*3) + (0*2 + 3*2)j]) = conj(3 + 6j) = 3 - 6j. This should be caused by the already-updated a.real being used when computing a.imag.

T i = a.imag * b.real + b.imag * a.real;
a.real = r;
a.imag = i;
return a;
#endif
}
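
The aliasing problem described in the thread above can be reproduced outside Paddle. Below is a minimal standalone C++ sketch (editorial, not part of this PR) that contrasts the old in-place multiply with the fixed one:

// Minimal standalone sketch: reproduces the aliasing bug discussed above,
// where a.imag is computed from the already-updated a.real.
#include <cstdio>

struct C {
  double real;
  double imag;
};

// Old behaviour: a.real is overwritten before a.imag is computed.
C mul_inplace_old(C a, C b) {
  a.real = a.real * b.real - a.imag * b.imag;
  a.imag = a.imag * b.real + b.imag * a.real;  // reads the NEW a.real
  return a;
}

// New behaviour: both parts are computed into temporaries first.
C mul_inplace_new(C a, C b) {
  double r = a.real * b.real - a.imag * b.imag;
  double i = a.imag * b.real + b.imag * a.real;
  a.real = r;
  a.imag = i;
  return a;
}

int main() {
  C a{1.0, 0.0};
  C b{2.0, 3.0};
  C old_res = mul_inplace_old(a, b);  // prints 2+6i (wrong)
  C new_res = mul_inplace_new(a, b);  // prints 2+3i, i.e. (1 + 0i) * (2 + 3i)
  std::printf("old: %g%+gi\n", old_res.real, old_res.imag);
  std::printf("new: %g%+gi\n", new_res.real, new_res.imag);
  return 0;
}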
11 changes: 8 additions & 3 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -304,7 +304,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
@@ -364,7 +365,9 @@ PD_REGISTER_KERNEL(square_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(square_double_grad,
CPU,
ALL_LAYOUT,
@@ -373,7 +376,9 @@ PD_REGISTER_KERNEL(square_double_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_double_grad,
CPU,
14 changes: 11 additions & 3 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -198,7 +198,7 @@ PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -228,8 +228,16 @@ PD_REGISTER_KERNEL(expm1,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(square,
CPU,
ALL_LAYOUT,
phi::SquareKernel,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
62 changes: 62 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -323,6 +323,24 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct ReciprocalGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<ComplexType<T>>(-1) *
(out * out).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

// 1st reverse grad
// y = cos(x)
// x --> y
@@ -704,6 +722,22 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SquareGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<ComplexType<T>>(2) * x.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
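
An editorial aside (not part of the PR) on why the conjugate appears in the two complex specializations above, assuming Paddle stores the gradient of a complex tensor as g_x = dL/dRe(x) + i*dL/dIm(x), i.e. g_x = 2*dL/d(conj(x)) in Wirtinger notation:

$$
g_x \;=\; 2\,\frac{\partial L}{\partial \bar{x}}
    \;=\; 2\,\frac{\partial L}{\partial \overline{\mathrm{out}}}\cdot
          \frac{\partial \overline{\mathrm{out}}}{\partial \bar{x}}
    \;=\; g_{\mathrm{out}}\cdot\overline{f'(x)}
\qquad \text{for holomorphic } \mathrm{out}=f(x),
$$

where the term involving d(out)/d(conj(x)) vanishes because f is holomorphic. With f(x) = x^2 this gives dx = dout * 2 * conj(x); with f(x) = 1/x (so f'(x) = -1/x^2 = -out^2) it gives dx = -dout * conj(out * out), matching SquareGradFunctor and ReciprocalGradFunctor above.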

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
@@ -3220,6 +3254,20 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSquareGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);

// dx = dout * 2 * conj(x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout * two * conj(x));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
@@ -3268,6 +3316,20 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaReciprocalGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = -dout * conj(out^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return -dout * conj(out * out);
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
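
Below is a small host-side sketch (editorial, not part of this PR) that checks with std::complex that the expressions used by the two CUDA functors above agree with the direct dout * conj(f'(x)) forms at a sample point:

// Sanity check of the complex grad expressions on the host with std::complex.
#include <cassert>
#include <complex>

int main() {
  using C = std::complex<double>;
  const C x(0.7, -1.3);
  const C dout(0.2, 0.5);

  // square: f(x) = x^2, f'(x) = 2x
  C dx_square_kernel = dout * 2.0 * std::conj(x);        // CudaSquareGradFunctor form
  C dx_square_direct = dout * std::conj(2.0 * x);        // dout * conj(f'(x))
  assert(std::abs(dx_square_kernel - dx_square_direct) < 1e-12);

  // reciprocal: out = 1/x, f'(x) = -1/x^2 = -out^2
  C out = 1.0 / x;
  C dx_recip_kernel = -dout * std::conj(out * out);      // CudaReciprocalGradFunctor form
  C dx_recip_direct = dout * std::conj(-1.0 / (x * x));  // dout * conj(f'(x))
  assert(std::abs(dx_recip_kernel - dx_recip_direct) < 1e-12);
  return 0;
}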

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
11 changes: 8 additions & 3 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -380,7 +380,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
@@ -431,7 +432,9 @@ PD_REGISTER_KERNEL(square_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(square_double_grad,
GPU,
ALL_LAYOUT,
@@ -441,7 +444,9 @@ PD_REGISTER_KERNEL(square_double_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_double_grad,
GPU,
6 changes: 4 additions & 2 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -247,7 +247,7 @@ PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -285,7 +285,9 @@ PD_REGISTER_KERNEL(square,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
39 changes: 38 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -3414,6 +3414,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.reciprocal(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3423,12 +3428,29 @@ def setUp(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.01, check_pir=True)
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(
['X'], 'Out', max_relative_error=0.03, check_pir=True
Contributor:

I see complex types are special-cased here with max_relative_error=0.03. How large is the absolute error, and would max_relative_error=0.02 pass?

Contributor Author:

On my machine the measured error was slightly above 0.02, so I relaxed it to 0.03.

)
else:
self.check_grad(
['X'], 'Out', max_relative_error=0.01, check_pir=True
)

def test_check_output(self):
self.check_output(check_pir=True)


class TestReciprocal_Complex64(TestReciprocal):
def init_dtype(self):
self.dtype = np.complex64


class TestReciprocal_Complex128(TestReciprocal):
def init_dtype(self):
self.dtype = np.complex128


class TestReciprocal_ZeroDim(TestReciprocal):
def init_shape(self):
self.shape = []
@@ -3799,6 +3821,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.square(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3814,6 +3841,16 @@ def test_check_output(self):
self.check_output(check_pir=True)


class TestSquare_Complex64(TestSquare):
def init_dtype(self):
self.dtype = np.complex64


class TestSquare_Complex128(TestSquare):
def init_dtype(self):
self.dtype = np.complex128


class TestSquare_ZeroDim(TestSquare):
def init_shape(self):
self.shape = []