-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【complex op】No.贰〇 add complex support for tanhshrink #56277
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should add testing cases for complex data type
got it~ |
self.x_np = np.random.uniform(10, 20, [10, 17]).astype( | ||
np.float64 | ||
) + 1j * np.random.uniform(10, 20, [10, 17]).astype(np.float64) | ||
self.place = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.place
在 TestTanhshrinkComplex
中应该是没有用到的
@@ -967,6 +967,40 @@ def test_errors(self): | |||
F.tanhshrink(x_fp16) | |||
|
|||
|
|||
class TestTanhshrinkComplex(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议添加一个继承 TestTanhshrink
类的 Case
else paddle.CPUPlace() | ||
) | ||
|
||
def test_complex64(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_complex64
和 def test_complex128
应该都没有在 GPU 上进行测试吧
@@ -1642,6 +1642,24 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> { | |||
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } | |||
}; | |||
|
|||
template <typename T> | |||
struct TanhShrinkGradFunctor<ComplexType<T>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TanhShrinkGradFunctor
只针对 CPU,还需要补充 CudaTanhShrinkGradFunctor
下的 complex 类型
Sorry to inform you that 42b21f3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR types
New features
PR changes
OPs
Description
suppot complex for tanhshrink, without changing in
python/paddle/tensor/ops.py
andtest/legacy_test/test_activation_op.py
#56145