【complex op】 No.10 add complex support for exp/expm1 #56398

Closed · wants to merge 14 commits · showing changes from 7 commits

23 changes: 23 additions & 0 deletions paddle/phi/common/complex.h
@@ -466,6 +466,29 @@ HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#endif
}

template <typename T>
Contributor: A previous PR should have already added exp here; please sync with the latest code so it isn't duplicated.

Author: OK.

HOSTDEVICE inline complex<T> exp(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::exp(thrust::complex<T>(a)));
#else
return complex<T>(std::exp(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> expm1(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
// thrust does not support expm1
return complex<T>(thrust::exp(thrust::complex<T>(a)) - 1);
Contributor: The 1 may need to be cast to a complex type. Also, the expm1 implementation could perhaps be moved into activation_functor.h.

#else
// std::expm1 does not support complex types
return complex<T>(std::exp(std::complex<T>(a)) -
static_cast<std::complex<T>>(1));
#endif
}
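
Note that both fallbacks compute exp(z) - 1 directly, which cancels badly for |z| near zero. Below is a minimal, host-side sketch of a cancellation-free alternative; it is not part of this PR and assumes complex<T> exposes public real/imag members and a (re, im) constructor, as phi's complex does:

// Sketch only (not in the PR): for z = x + iy,
//   Re(e^z - 1) = expm1(x) * cos(y) - 2 * sin(y/2)^2
//   Im(e^z - 1) = exp(x) * sin(y)
// Every term stays small as z -> 0, so no catastrophic cancellation.
template <typename T>
complex<T> expm1_stable(const complex<T>& a) {
  T x = a.real;  // assumption: public data members
  T y = a.imag;
  T s = std::sin(y / T(2));
  T re = std::expm1(x) * std::cos(y) - T(2) * s * s;
  T im = std::exp(x) * std::sin(y);
  return complex<T>(re, im);
}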

template <typename T>
HOSTDEVICE inline complex<T> conj(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
8 changes: 6 additions & 2 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -340,15 +340,19 @@ PD_REGISTER_KERNEL(exp_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(expm1_grad,
CPU,
ALL_LAYOUT,
phi::Expm1GradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(
logit_grad, CPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {}
8 changes: 6 additions & 2 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -211,7 +211,9 @@ PD_REGISTER_KERNEL(exp,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(expm1,
CPU,
@@ -221,7 +223,9 @@ PD_REGISTER_KERNEL(expm1,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
75 changes: 75 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -96,6 +96,25 @@ struct Cosine<dtype::bfloat16> {
}
};

template <typename T>
struct Exp {
Contributor: It may not be necessary to define Exp explicitly here?

HOSTDEVICE T operator()(const T& val) const { return exp(val); }
};

template <>
struct Exp<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(exp(static_cast<float>(val)));
}
};

template <>
struct Exp<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(exp(static_cast<float>(val)));
}
};

template <typename T>
using ComplexType = phi::dtype::complex<T>;

@@ -1167,6 +1186,22 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct ExpGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
Contributor: This should be (Device d, X x UNUSED, Out out, dOut dout, dX dx).

dx.device(d) =
dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>());
Contributor: Taking the conjugate of out alone should be enough here.

}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
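
Reading the two review comments together (x becomes the UNUSED parameter, and the conjugate of out alone suffices since out = exp(x)), the suggested specialization might look like the sketch below. The switch to kDepOut is my assumption: once the functor reads out instead of x, it depends on the forward output.

template <typename T>
struct ExpGradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    // out = exp(x), so dx = dout * conj(out) without recomputing exp(x)
    dx.device(d) = dout * out.unaryExpr(Conj<T>());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};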

// expm1(x) = e^x - 1
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
@@ -1194,6 +1229,22 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct Expm1GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
Contributor: This should be (Device d, X x UNUSED, Out out, dOut dout, dX dx).

dx.device(d) =
dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>());
Contributor: The derivative here does not look correct.

}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
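
For reference, conj(exp(x)) equals conj(out + 1) when out = expm1(x), so the same gradient can be written in terms of out, mirroring the real-valued Expm1GradFunctor. A sketch of that rewrite (my reading of the review, not the merged code):

template <typename T>
struct Expm1GradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    // dx = dout * conj(out + 1) = dout * conj(out) + dout
    dx.device(d) = dout * out.unaryExpr(Conj<T>()) + dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};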

// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
@@ -2866,6 +2917,18 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaExpGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout * conj(out), where out = exp(x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return static_cast<ComplexType<T>>(dout * conj(out));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
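
One detail worth flagging (my observation, not raised in the review): the operator consumes out, yet FwdDeps() returns kDepX, unlike the real-valued CudaExpGradFunctor, which depends on the output. A sketch of the presumably intended declaration:

template <typename T>
struct CudaExpGradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  // dx = dout * conj(out), with out = exp(x) taken from the forward pass
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> out) const {
    return dout * conj(out);
  }

  // Depend on the forward output so the second argument really is exp(x).
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};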

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -2918,6 +2981,18 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaExpm1GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout * conj(out)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return static_cast<ComplexType<T>>(dout * conj(out));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
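
This functor multiplies by conj(out) with out = expm1(x), which drops the "+1" in expm1'(x) = exp(x) = out + 1, and it also declares kDepX while reading out. A sketch of the presumed fix (an assumption, not the merged code):

template <typename T>
struct CudaExpm1GradFunctor<ComplexType<T>>
    : public BaseActivationFunctor<ComplexType<T>> {
  // dx = dout * conj(out + 1) = dout * conj(out) + dout
  __device__ __forceinline__ ComplexType<T> operator()(
      const ComplexType<T> dout, const ComplexType<T> out) const {
    return dout * conj(out) + dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};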

template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
8 changes: 6 additions & 2 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -398,7 +398,9 @@ PD_REGISTER_KERNEL(exp_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
@@ -415,7 +417,9 @@ PD_REGISTER_KERNEL(expm1_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(square_grad,
GPU,
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -261,7 +261,10 @@ PD_REGISTER_KERNEL(exp,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(expm1,
GPU,
ALL_LAYOUT,
@@ -271,7 +274,10 @@ PD_REGISTER_KERNEL(expm1,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(square,
GPU,
ALL_LAYOUT,
11 changes: 10 additions & 1 deletion python/paddle/tensor/ops.py
@@ -640,7 +640,16 @@ def expm1(x, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
[
Contributor: Please update this method's docstring accordingly as well.

'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'expm1',
)
helper = LayerHelper('expm1', **locals())