From 54fde6d7110f4d53ad97f6a91ab909c4c94cf7cd Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Wed, 8 Nov 2023 13:26:36 +0800 Subject: [PATCH] fix --- test/legacy_test/test_activation_op.py | 29 +++++++++++++++----------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index b1d77a535ceb8c..9de59e4f3c4ef9 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -1523,21 +1523,26 @@ def test_check_grad(self): if self.dtype == np.float16: return if self.dtype not in [np.complex64, np.complex128]: - self.check_grad( - ['X'], - 'Out', - check_prim=True, - check_pir=True, - check_prim_pir=True, - ) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) else: - self.check_grad( - ['X'], - 'Out', - ) + self.check_grad( + ['X'], + 'Out', + ) def test_check_output(self): - self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) + if self.dtype not in [np.complex64, np.complex128]: + self.check_output( + check_prim=True, check_pir=True, check_prim_pir=True + ) + else: + self.check_output() class TestSqrtComplex64(TestSqrt):