-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【complex op】 No.10 add complex support for exp/expm1 #56398
Changes from 7 commits
0fd0b59
6c1a3ef
da55715
ed85ed9
077f534
11cf229
b43efdf
1d4e2da
4f09e7b
6eacc19
989f273
0210cee
4083ff1
9513cce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -466,6 +466,29 @@ HOSTDEVICE inline complex<T> tanh(const complex<T>& a) { | |
#endif | ||
} | ||
|
||
template <typename T> | ||
HOSTDEVICE inline complex<T> exp(const complex<T>& a) { | ||
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ | ||
(defined(__CUDA_ARCH__) || defined(__HIPCC__)) | ||
return complex<T>(thrust::exp(thrust::complex<T>(a))); | ||
#else | ||
return complex<T>(std::exp(std::complex<T>(a))); | ||
#endif | ||
} | ||
|
||
template <typename T> | ||
HOSTDEVICE inline complex<T> expm1(const complex<T>& a) { | ||
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ | ||
(defined(__CUDA_ARCH__) || defined(__HIPCC__)) | ||
// thrust does not support expm1 | ||
return complex<T>(thrust::exp(thrust::complex<T>(a)) - 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
#else | ||
// expm1 in C++ does not support complex types | ||
return complex<T>(std::exp(std::complex<T>(a)) - | ||
static_cast<std::complex<T>>(1)); | ||
#endif | ||
} | ||
|
||
template <typename T> | ||
HOSTDEVICE inline complex<T> conj(const complex<T>& a) { | ||
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,6 +96,25 @@ struct Cosine<dtype::bfloat16> { | |
} | ||
}; | ||
|
||
template <typename T> | ||
struct Exp { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这一块或许并不需要显示地定义 |
||
HOSTDEVICE T operator()(const T& val) const { return exp(val); } | ||
}; | ||
|
||
template <> | ||
struct Exp<dtype::float16> { | ||
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { | ||
return dtype::float16(exp(static_cast<float>(val))); | ||
} | ||
}; | ||
|
||
template <> | ||
struct Exp<dtype::bfloat16> { | ||
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const { | ||
return dtype::bfloat16(exp(static_cast<float>(val))); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
using ComplexType = phi::dtype::complex<T>; | ||
|
||
|
@@ -1167,6 +1186,22 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> { | |
} | ||
}; | ||
|
||
template <typename T> | ||
struct ExpGradFunctor<ComplexType<T>> | ||
: public BaseActivationFunctor<ComplexType<T>> { | ||
template <typename Device, | ||
typename X, | ||
typename Out, | ||
typename dOut, | ||
typename dX> | ||
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 应该是 |
||
dx.device(d) = | ||
dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 只需要求 |
||
} | ||
|
||
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } | ||
}; | ||
|
||
// expm1(x) = e^x - 1 | ||
template <typename T> | ||
struct Expm1Functor : public BaseActivationFunctor<T> { | ||
|
@@ -1194,6 +1229,22 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> { | |
} | ||
}; | ||
|
||
template <typename T> | ||
struct Expm1GradFunctor<ComplexType<T>> | ||
: public BaseActivationFunctor<ComplexType<T>> { | ||
template <typename Device, | ||
typename X, | ||
typename Out, | ||
typename dOut, | ||
typename dX> | ||
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 应该是 |
||
dx.device(d) = | ||
dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块的求导应该不对 |
||
} | ||
|
||
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } | ||
}; | ||
|
||
// relu(x) = max(x, 0) | ||
template <typename T> | ||
struct ReluCPUFunctor : public BaseActivationFunctor<T> { | ||
|
@@ -2866,6 +2917,18 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> { | |
} | ||
}; | ||
|
||
template <typename T> | ||
struct CudaExpGradFunctor<ComplexType<T>> | ||
: public BaseActivationFunctor<ComplexType<T>> { | ||
// dx = dout * exp(x) | ||
__device__ __forceinline__ ComplexType<T> operator()( | ||
const ComplexType<T> dout, const ComplexType<T> out) const { | ||
return static_cast<ComplexType<T>>(dout * conj(out)); | ||
} | ||
|
||
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } | ||
}; | ||
|
||
template <typename T> | ||
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> { | ||
using MPType = typename phi::dtype::MPTypeTrait<T>::Type; | ||
|
@@ -2918,6 +2981,18 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> { | |
} | ||
}; | ||
|
||
template <typename T> | ||
struct CudaExpm1GradFunctor<ComplexType<T>> | ||
: public BaseActivationFunctor<ComplexType<T>> { | ||
// dx = dout * exp(x) | ||
__device__ __forceinline__ ComplexType<T> operator()( | ||
const ComplexType<T> dout, const ComplexType<T> out) const { | ||
return static_cast<ComplexType<T>>(dout * conj(out)); | ||
} | ||
|
||
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } | ||
}; | ||
|
||
template <typename T> | ||
struct CudaSinFunctor : public BaseActivationFunctor<T> { | ||
using MPType = typename phi::dtype::MPTypeTrait<T>::Type; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -640,7 +640,16 @@ def expm1(x, name=None): | |
check_variable_and_dtype( | ||
x, | ||
'x', | ||
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], | ||
[ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 该方法的 docstring 也相应修改一下 |
||
'float16', | ||
'uint16', | ||
'float32', | ||
'float64', | ||
'int32', | ||
'int64', | ||
'complex64', | ||
'complex128', | ||
], | ||
'expm1', | ||
) | ||
helper = LayerHelper('expm1', **locals()) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里之前的pr应该已经加了exp, 可以同步一下最新的代码,不要加重复了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的