complex No.15 sqrt op #57661

Closed · wants to merge 7 commits
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -305,7 +305,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -199,7 +199,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

35 changes: 34 additions & 1 deletion paddle/phi/kernels/funcs/activation_functor.h
@@ -729,6 +729,24 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

+template <typename T>
+struct SqrtGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};

// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
@@ -2740,7 +2758,6 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

-// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
@@ -3843,6 +3860,22 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

+template <typename T>
+struct CudaSqrtGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one_half = static_cast<ComplexType<T>>(0.5f);
+
+  // dx = dout * conj(0.5 / out)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return dout * conj(one_half / out);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
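The two specializations above implement the same backward rule under the conjugate (Wirtinger) gradient convention for holomorphic ops: with out = sqrt(x), the incoming gradient is multiplied by the conjugate of the analytic derivative, i.e. dx = dout * conj(0.5 / out). Below is a minimal NumPy sketch, not part of the PR (the loss and the names w and dx_numeric are invented for this check), that verifies the formula against finite differences:

import numpy as np

# Hedged illustration only: checks dx = dout * conj(0.5 / out) for out = sqrt(x)
# against numeric derivatives of a real-valued loss L = Re(sum(w * sqrt(x))).
rng = np.random.default_rng(0)
x = rng.uniform(0.1, 1.0, 4) + 1j * rng.uniform(0.1, 1.0, 4)
w = rng.standard_normal(4) + 1j * rng.standard_normal(4)

out = np.sqrt(x)
dout = np.conj(w)  # gradient of L with respect to out under this convention
dx_analytic = dout * np.conj(0.5 / out)  # formula used by the new functors

def loss(z):
    return np.real(np.sum(w * np.sqrt(z)))

eps = 1e-7
dx_numeric = np.zeros_like(x)
for k in range(x.size):
    e = np.zeros_like(x)
    e[k] = eps
    d_re = (loss(x + e) - loss(x - e)) / (2 * eps)       # d L / d Re(x_k)
    d_im = (loss(x + 1j * e) - loss(x - 1j * e)) / (2 * eps)  # d L / d Im(x_k)
    dx_numeric[k] = d_re + 1j * d_im

assert np.allclose(dx_analytic, dx_numeric, atol=1e-6)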
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -385,7 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -248,7 +248,7 @@ PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
-PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

9 changes: 8 additions & 1 deletion python/paddle/tensor/ops.py
@@ -1063,7 +1063,14 @@ def sqrt(x, name=None):
    check_variable_and_dtype(
        x,
        'x',
-       ['float16', 'uint16', 'float32', 'float64'],
+       [
+           'float16',
+           'uint16',
+           'float32',
+           'float64',
+           'complex64',
+           'complex128',
+       ],
        'sqrt',
    )
    helper = LayerHelper('sqrt', **locals())
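With the dtype check relaxed to accept complex64 and complex128, a dynamic-graph call along the following lines should work once this change lands. This is a hedged usage sketch rather than code from the PR; the input values are arbitrary and the expected gradient follows from the new grad functor:

import numpy as np
import paddle

x = paddle.to_tensor(
    np.array([1 + 1j, 4 + 0j], dtype=np.complex128), stop_gradient=False
)
y = paddle.sqrt(x)                 # forward: should match np.sqrt on complex input
loss = paddle.sum(paddle.real(y))  # real-valued loss so backward is well defined
loss.backward()
print(y.numpy())                   # expected: np.sqrt([1 + 1j, 4 + 0j])
print(x.grad.numpy())              # expected: conj(0.5 / np.sqrt(x)) per the new grad rule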
4 changes: 3 additions & 1 deletion test/legacy_test/op_test.py
@@ -1732,7 +1732,9 @@ def _dfs_grad_op(op_desc, fwd_op_desc=None):
    has_infer_inplace = base.core.has_infer_inplace(op_desc.type())
    has_grad_op_maker = base.core.has_grad_op_maker(op_desc.type())
    has_infer_inplace_in_grad_descendants = False
-   if not has_grad_op_maker:
+   # the OP test doesn't support higher order grad
+   is_grad_op_desc = op_desc.type().endswith('_grad')
Reviewer (Contributor):

Why does the op test need to be changed here?

Author (Contributor):

Because OpTest should not contain tests for higher-order gradients, and this check keeps the dfs function from descending into higher-order grad op_descs (see the sketch after this file's diff).

+   if not has_grad_op_maker or is_grad_op_desc:
        has_infer_inplace_in_descendants = False
    else:
        # get grad_op_desc
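As a standalone illustration of the new guard (should_descend and the printed op names are hypothetical, not real OpTest helpers): the DFS now stops as soon as it reaches a desc whose type already ends in '_grad', so it never asks for the grad of a grad op:

def should_descend(op_type, has_grad_op_maker=True):
    # mirrors the new condition: stop if there is no grad op maker,
    # or if this desc is already a grad op
    is_grad_op_desc = op_type.endswith('_grad')
    return has_grad_op_maker and not is_grad_op_desc

print(should_descend('sqrt'))       # True  -> DFS continues into sqrt_grad
print(should_descend('sqrt_grad'))  # False -> higher-order grad descs are skipped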
45 changes: 36 additions & 9 deletions test/legacy_test/test_activation_op.py
@@ -1503,7 +1503,13 @@ def setUp(self):
        self.if_enable_cinn()

        np.random.seed(1023)
-       x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+       if self.dtype is np.complex64 or self.dtype is np.complex128:
+           x = (
+               np.random.uniform(0.1, 1, self.shape)
+               + 1j * np.random.uniform(0.1, 1, self.shape)
+           ).astype(self.dtype)
+       else:
+           x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        out = np.sqrt(x)

        self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -1516,16 +1522,37 @@ def if_enable_cinn(self):
    def test_check_grad(self):
        if self.dtype == np.float16:
            return
-       self.check_grad(
-           ['X'],
-           'Out',
-           check_prim=True,
-           check_pir=True,
-           check_prim_pir=True,
-       )
+       if self.dtype not in [np.complex64, np.complex128]:
+           self.check_grad(
+               ['X'],
+               'Out',
+               check_prim=True,
+               check_pir=True,
+               check_prim_pir=True,
+           )
+       else:
+           self.check_grad(
+               ['X'],
+               'Out',
+           )

    def test_check_output(self):
-       self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)
+       if self.dtype not in [np.complex64, np.complex128]:
+           self.check_output(
+               check_prim=True, check_pir=True, check_prim_pir=True
+           )
+       else:
+           self.check_output()


+class TestSqrtComplex64(TestSqrt):
+   def init_dtype(self):
+       self.dtype = np.complex64
+
+
+class TestSqrtComplex128(TestSqrt):
+   def init_dtype(self):
+       self.dtype = np.complex128
Reviewer (Contributor):

Please check whether these two unit tests actually run; if they do, post a screenshot here.
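For reference, the two classes can be exercised locally with something like `python -m unittest test_activation_op.TestSqrtComplex64 test_activation_op.TestSqrtComplex128`, run from test/legacy_test in a built develop tree.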



class TestSqrtPrimFp32(TestActivation):
Expand Down