【complex op】No.19 add complex support for triangular_solve #59529

Merged · 14 commits · Jan 3, 2024
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
@@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
                    ALL_LAYOUT,
                    phi::TriangularSolveGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_kernel.cc
@@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/blas/blas_impl.h
@@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
                    const phi::dtype::complex<float> alpha,
                    const phi::dtype::complex<float> *A,
                    const int lda,
-                   phi::dtype::complex<double> *B,
+                   phi::dtype::complex<float> *B,
Contributor:
Why was this changed? Did you look at the PR that added this line? Why did it previously use phi::dtype::complex<double>?
Contributor Author:
This function calls cblas_ctrsm, the single-precision complex CBLAS routine for solving triangular systems, and just below there is a call to cblas_ztrsm, the double-precision complex version. A and B should therefore have the same type; I believe the original was a typo.

Contributor Author:
Hello, could you please help review this again?

const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
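To make the dtype pairing above concrete, here is a minimal SciPy sketch (an analogy only, not Paddle code): the single-precision complex triangular solve works entirely in complex64 and the double-precision one entirely in complex128, which is why A and B must share one type.

```python
# Analogy to the cblas_ctrsm / cblas_ztrsm pairing discussed above:
# the single-precision complex triangular solve keeps A, B, and the
# result in complex64; the double-precision one keeps them in complex128.
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
a = np.triu(rng.random((3, 3)) + 1j * rng.random((3, 3))).astype(np.complex64)
b = (rng.random((3, 2)) + 1j * rng.random((3, 2))).astype(np.complex64)

x = solve_triangular(a, b)         # single-precision complex path
assert x.dtype == np.complex64

x_z = solve_triangular(a.astype(np.complex128), b.astype(np.complex128))
assert x_z.dtype == np.complex128  # double-precision complex path
```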
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cc
@@ -55,6 +55,8 @@ class MatrixReduceSumFunctor<T, CPUContext> {

template class MatrixReduceSumFunctor<float, CPUContext>;
template class MatrixReduceSumFunctor<double, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cu
@@ -52,6 +52,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {

template class MatrixReduceSumFunctor<float, GPUContext>;
template class MatrixReduceSumFunctor<double, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
+template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;

} // namespace funcs
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_kernel.cu
@@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
                    ALL_LAYOUT,
                    phi::TriangularSolveKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
14 changes: 10 additions & 4 deletions python/paddle/tensor/linalg.py
@@ -3186,9 +3186,9 @@ def triangular_solve(

     Args:
         x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
-            more batch dimensions. Its data type should be float32 or float64.
+            more batch dimensions. Its data type should be float32, float64, complex64, complex128.
         y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
-            zero or more batch dimensions. Its data type should be float32 or float64.
+            zero or more batch dimensions. Its data type should be float32, float64, complex64, complex128.
         upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
             system of equations. Default: True.
         transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
@@ -3227,10 +3227,16 @@ def triangular_solve(
     inputs = {"X": [x], "Y": [y]}
     helper = LayerHelper("triangular_solve", **locals())
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64'], 'triangular_solve'
+        x,
+        'x',
+        ['float32', 'float64', 'complex64', 'complex128'],
+        'triangular_solve',
     )
     check_variable_and_dtype(
-        y, 'y', ['float32', 'float64'], 'triangular_solve'
+        y,
+        'y',
+        ['float32', 'float64', 'complex64', 'complex128'],
+        'triangular_solve',
     )
out = helper.create_variable_for_type_inference(dtype=x.dtype)
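For reference, a small usage sketch of the extended Python API (values are illustrative; complex inputs are what this PR enables):

```python
import paddle

# Solve the upper-triangular complex system x @ out = y.
x = paddle.to_tensor([[1 + 1j, 2 + 0j],
                      [0 + 0j, 3 - 1j]], dtype='complex64')
y = paddle.to_tensor([[3 + 1j],
                      [6 - 2j]], dtype='complex64')

out = paddle.linalg.triangular_solve(x, y, upper=True)
print(out)  # expected: [[0 + 1j], [2 + 0j]], dtype complex64
```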
