diff --git a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
index 80b2015f7318a..95e96b6d7918c 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
index 6245eb9042640..68af8bc2b1e92 100644
--- a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
+++ b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc
@@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.h b/paddle/phi/kernels/funcs/blas/blas_impl.h
index ffafe15b8fcf2..b4ee437011f66 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.h
@@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
                    const phi::dtype::complex<float> alpha,
                    const phi::dtype::complex<float> *A,
                    const int lda,
-                   phi::dtype::complex<double> *B,
+                   phi::dtype::complex<float> *B,
                    const int ldb) {
     cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
   }
diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cc b/paddle/phi/kernels/funcs/matrix_reduce.cc
index e20d98984eb5a..03bdc820abe07 100644
--- a/paddle/phi/kernels/funcs/matrix_reduce.cc
+++ b/paddle/phi/kernels/funcs/matrix_reduce.cc
@@ -55,6 +55,8 @@ class MatrixReduceSumFunctor<T, CPUContext> {
 
 template class MatrixReduceSumFunctor<float, CPUContext>;
 template class MatrixReduceSumFunctor<double, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu
index f4305914c4171..39bb62a6bf303 100644
--- a/paddle/phi/kernels/funcs/matrix_reduce.cu
+++ b/paddle/phi/kernels/funcs/matrix_reduce.cu
@@ -52,6 +52,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {
 
 template class MatrixReduceSumFunctor<float, GPUContext>;
 template class MatrixReduceSumFunctor<double, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
index f7eaa48579794..67861b282529b 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
index 2a943fd0ac681..342b8e3885d7b 100644
--- a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
+++ b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu
@@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 212125825d9b1..2557954057104 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3186,9 +3186,9 @@ def triangular_solve(
 
     Args:
         x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
-            more batch dimensions. Its data type should be float32 or float64.
+            more batch dimensions. Its data type should be float32, float64, complex64 or complex128.
         y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
-            zero or more batch dimensions. Its data type should be float32 or float64.
+            zero or more batch dimensions. Its data type should be float32, float64, complex64 or complex128.
         upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the
             lower-triangular system of equations. Default: True.
         transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
@@ -3227,10 +3227,16 @@ def triangular_solve(
         inputs = {"X": [x], "Y": [y]}
         helper = LayerHelper("triangular_solve", **locals())
         check_variable_and_dtype(
-            x, 'x', ['float32', 'float64'], 'triangular_solve'
+            x,
+            'x',
+            ['float32', 'float64', 'complex64', 'complex128'],
+            'triangular_solve',
         )
         check_variable_and_dtype(
-            y, 'y', ['float32', 'float64'], 'triangular_solve'
+            y,
+            'y',
+            ['float32', 'float64', 'complex64', 'complex128'],
+            'triangular_solve',
         )
 
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
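# Illustrative usage sketch (an assumption, not part of this patch): with the
# complex64/complex128 kernels above registered, paddle.linalg.triangular_solve
# can be called on complex inputs in dynamic-graph mode. The shapes, seed, and
# tolerance below are arbitrary choices for the sketch.
import numpy as np
import paddle

rng = np.random.default_rng(0)
# Build a well-conditioned lower-triangular complex system A @ out = b.
a = np.tril(rng.random((3, 3)) + 1j * rng.random((3, 3))) + 3.0 * np.eye(3)
b = (rng.random((3, 2)) + 1j * rng.random((3, 2))).astype(np.complex64)

x = paddle.to_tensor(a.astype(np.complex64))
y = paddle.to_tensor(b)
out = paddle.linalg.triangular_solve(x, y, upper=False)

# The result should agree with NumPy's dense solve of the same system.
np.testing.assert_allclose(out.numpy(), np.linalg.solve(a, b), rtol=1e-4)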
diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py
index f3624b5332817..d4aecda8780ce 100644
--- a/test/legacy_test/test_triangular_solve_op.py
+++ b/test/legacy_test/test_triangular_solve_op.py
@@ -51,10 +51,23 @@ def setUp(self):
         self.python_api = paddle.tensor.linalg.triangular_solve
         self.config()
 
-        self.inputs = {
-            'X': np.random.random(self.x_shape).astype(self.dtype),
-            'Y': np.random.random(self.y_shape).astype(self.dtype),
-        }
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            self.inputs = {
+                'X': (
+                    np.random.random(self.x_shape)
+                    + 1j * np.random.random(self.x_shape)
+                ).astype(self.dtype),
+                'Y': (
+                    np.random.random(self.y_shape)
+                    + 1j * np.random.random(self.y_shape)
+                ).astype(self.dtype),
+            }
+        else:
+            self.inputs = {
+                'X': np.random.random(self.x_shape).astype(self.dtype),
+                'Y': np.random.random(self.y_shape).astype(self.dtype),
+            }
+
         self.attrs = {
             'upper': self.upper,
             'transpose': self.transpose,
@@ -248,6 +261,485 @@ def set_output(self):
         self.output = np.matmul(np.linalg.inv(x), y)
 
 
+# 3D(broadcast) + 3D complex64
+class TestTriangularSolveOpCp643b3(TestTriangularSolveOp):
+    """
+    case 10
+    """
+
+    def config(self):
+        self.x_shape = [1, 10, 10]
+        self.y_shape = [6, 10, 12]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 2D + 2D upper complex64
+class TestTriangularSolveOpCp6422Up(TestTriangularSolveOp):
+    """
+    case 11
+    """
+
+    def config(self):
+        self.x_shape = [12, 12]
+        self.y_shape = [12, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            max_relative_error=0.02,
+        )
+
+
+# 2D(broadcast) + 3D, test 'transpose' complex64
+class TestTriangularSolveOpCp6423T(TestTriangularSolveOp):
+    """
+    case 12
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [3, 10, 8]
+        self.upper = False
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X']).transpose(1, 0)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 2D + 2D , test 'unitriangular' complex64
+class TestTriangularSolveOpCp6422Un(TestTriangularSolveOp):
+    """
+    case 13
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [10, 10]
+        self.upper = True
+        self.transpose = False
+        self.unitriangular = True
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.triu(self.inputs['X'])
+        np.fill_diagonal(x, 1.0 + 0j)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+
+# 4D(broadcast) + 4D(broadcast) complex64
+class TestTriangularSolveOpCp644b4b(TestTriangularSolveOp):
+    """
+    case 14
+    """
+
+    def config(self):
+        self.x_shape = [1, 3, 10, 10]
+        self.y_shape = [2, 3, 10, 5]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            max_relative_error=0.008,
+        )
+
+
+# 3D(broadcast) + 4D(broadcast), test 'upper' complex64
+class TestTriangularSolveOpCp643b4bUp(TestTriangularSolveOp):
+    """
+    case 15
+    """
+
+    def config(self):
+        self.x_shape = [2, 10, 10]
+        self.y_shape = [5, 1, 10, 2]
+        self.upper = True
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.triu(self.inputs['X']).transpose(0, 2, 1)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 5D complex64
+class TestTriangularSolveOpCp643b5(TestTriangularSolveOp):
+    """
+    case 16
+    """
+
+    def config(self):
+        self.x_shape = [12, 3, 3]
+        self.y_shape = [2, 3, 12, 3, 2]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 5D + 4D(broadcast) complex64
+class TestTriangularSolveOpCp6454b(TestTriangularSolveOp):
+    """
+    case 17
+    """
+
+    def config(self):
+        self.x_shape = [2, 4, 2, 3, 3]
+        self.y_shape = [4, 1, 3, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex64
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.matmul(np.linalg.inv(x), y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 3D complex128
+class TestTriangularSolveOpCp1283b3(TestTriangularSolveOp):
+    """
+    case 18
+    """
+
+    def config(self):
+        self.x_shape = [1, 10, 10]
+        self.y_shape = [6, 10, 12]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 2D + 2D upper complex128
+class TestTriangularSolveOpCp12822Up(TestTriangularSolveOp):
+    """
+    case 19
+    """
+
+    def config(self):
+        self.x_shape = [12, 12]
+        self.y_shape = [12, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 2D(broadcast) + 3D, test 'transpose' complex128
+class TestTriangularSolveOpCp12823T(TestTriangularSolveOp):
+    """
+    case 20
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [3, 10, 8]
+        self.upper = False
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X']).transpose(1, 0)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 2D + 2D , test 'unitriangular' complex128
+class TestTriangularSolveOpCp12822Un(TestTriangularSolveOp):
+    """
+    case 21
+    """
+
+    def config(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [10, 10]
+        self.upper = True
+        self.transpose = False
+        self.unitriangular = True
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.triu(self.inputs['X'])
+        np.fill_diagonal(x, 1.0 + 0j)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+        )
+
+
+# 4D(broadcast) + 4D(broadcast) complex128
+class TestTriangularSolveOpCp1284b4b(TestTriangularSolveOp):
+    """
+    case 22
+    """
+
+    def config(self):
+        self.x_shape = [1, 3, 10, 10]
+        self.y_shape = [2, 3, 10, 5]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 4D(broadcast), test 'upper' complex128
+class TestTriangularSolveOpCp1283b4bUp(TestTriangularSolveOp):
+    """
+    case 23
+    """
+
+    def config(self):
+        self.x_shape = [2, 10, 10]
+        self.y_shape = [5, 1, 10, 2]
+        self.upper = True
+        self.transpose = True
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.triu(self.inputs['X']).transpose(0, 2, 1)
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 3D(broadcast) + 5D complex128
+class TestTriangularSolveOpCp1283b5(TestTriangularSolveOp):
+    """
+    case 24
+    """
+
+    def config(self):
+        self.x_shape = [12, 3, 3]
+        self.y_shape = [2, 3, 12, 3, 2]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.linalg.solve(x, y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+# 5D + 4D(broadcast) complex128
+class TestTriangularSolveOpCp12854b(TestTriangularSolveOp):
+    """
+    case 25
+    """
+
+    def config(self):
+        self.x_shape = [2, 4, 2, 3, 3]
+        self.y_shape = [4, 1, 3, 10]
+        self.upper = False
+        self.transpose = False
+        self.unitriangular = False
+        self.dtype = np.complex128
+
+    def set_output(self):
+        x = np.tril(self.inputs['X'])
+        y = self.inputs['Y']
+        self.output = np.matmul(np.linalg.inv(x), y)
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
 class TestTriangularSolveAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2021)