Skip to content

Commit

Permalink
fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Nov 8, 2023
1 parent cda1638 commit 2e9afec
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
9 changes: 8 additions & 1 deletion python/paddle/tensor/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,14 @@ def sqrt(x, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64'],
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'sqrt',
)
helper = LayerHelper('sqrt', **locals())
Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,9 @@ def _dfs_grad_op(op_desc, fwd_op_desc=None):
has_infer_inplace = base.core.has_infer_inplace(op_desc.type())
has_grad_op_maker = base.core.has_grad_op_maker(op_desc.type())
has_infer_inplace_in_grad_descendants = False
if not has_grad_op_maker:
# the OP test doesn't support higher order grad
is_grad_op_desc = op_desc.type().endswith('_grad')
if not has_grad_op_maker or is_grad_op_desc:
has_infer_inplace_in_descendants = False
else:
# get grad_op_desc
Expand Down
29 changes: 17 additions & 12 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,21 +1523,26 @@ def test_check_grad(self):
if self.dtype == np.float16:
return
if self.dtype not in [np.complex64, np.complex128]:
self.check_grad(
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
self.check_grad(
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
else:
self.check_grad(
['X'],
'Out',
)
self.check_grad(
['X'],
'Out',
)

def test_check_output(self):
self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)
if self.dtype not in [np.complex64, np.complex128]:
self.check_output(
check_prim=True, check_pir=True, check_prim_pir=True
)
else:
self.check_output()


class TestSqrtComplex64(TestSqrt):
Expand Down

0 comments on commit 2e9afec

Please sign in to comment.