Skip to content

Commit

Permalink
【complex op】No.19 add complex support for triangular_solve (#59529)
Browse files Browse the repository at this point in the history
  • Loading branch information
zbt78 authored Jan 3, 2024
1 parent d890019 commit 5d01382
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 13 deletions.
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/blas/blas_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> alpha,
const phi::dtype::complex<float> *A,
const int lda,
phi::dtype::complex<double> *B,
phi::dtype::complex<float> *B,
const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class MatrixReduceSumFunctor<T, CPUContext> {

template class MatrixReduceSumFunctor<float, CPUContext>;
template class MatrixReduceSumFunctor<double, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {

template class MatrixReduceSumFunctor<float, GPUContext>;
template class MatrixReduceSumFunctor<double, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;

} // namespace funcs
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
14 changes: 10 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3192,9 +3192,9 @@ def triangular_solve(
Args:
x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
more batch dimensions. Its data type should be float32 or float64.
more batch dimensions. Its data type should be float32, float64, complex64, complex128.
y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
zero or more batch dimensions. Its data type should be float32 or float64.
zero or more batch dimensions. Its data type should be float32, float64, complex64, complex128.
upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
system of equations. Default: True.
transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
Expand Down Expand Up @@ -3233,10 +3233,16 @@ def triangular_solve(
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper("triangular_solve", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'triangular_solve'
x,
'x',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64'], 'triangular_solve'
y,
'y',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)

Expand Down
Loading

0 comments on commit 5d01382

Please sign in to comment.