diff --git a/Project.toml b/Project.toml index e437e885f..b84adab09 100644 --- a/Project.toml +++ b/Project.toml @@ -39,13 +39,14 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +LAPACK_jll = "51474c39-65e3-53ba-86ba-03b1b862ec14" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" [extensions] LinearSolveBandedMatricesExt = "BandedMatrices" -LinearSolveBLISExt = "blis_jll" +LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"] LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveEnzymeExt = "Enzyme" diff --git a/ext/LinearSolveBLISExt.jl b/ext/LinearSolveBLISExt.jl index b8c31b680..59304a968 100644 --- a/ext/LinearSolveBLISExt.jl +++ b/ext/LinearSolveBLISExt.jl @@ -2,15 +2,20 @@ module LinearSolveBLISExt using Libdl using blis_jll +using LAPACK_jll using LinearAlgebra using LinearSolve -using LinearAlgebra: BlasInt, LU +using LinearAlgebra: libblastrampoline, BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, @blasfunc, chkargsok using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase const global libblis = blis_jll.blis +const global liblapack = libblastrampoline + +BLAS.lbt_forward(libblis; clear=true, verbose=true, suffix_hint="64_") +BLAS.lbt_forward(LAPACK_jll.liblapack_path; suffix_hint="64_", verbose=true) function getrf!(A::AbstractMatrix{<:ComplexF64}; ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), @@ -24,7 +29,7 @@ function getrf!(A::AbstractMatrix{<:ComplexF64}; if isempty(ipiv) ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) end - ccall((@blasfunc(zgetrf_), libblis), Cvoid, + ccall((@blasfunc(zgetrf_), liblapack), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) @@ -44,7 +49,7 @@ function getrf!(A::AbstractMatrix{<:ComplexF32}; if isempty(ipiv) ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) end - ccall((@blasfunc(cgetrf_), libblis), Cvoid, + ccall((@blasfunc(cgetrf_), liblapack), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) @@ -64,7 +69,7 @@ function getrf!(A::AbstractMatrix{<:Float64}; if isempty(ipiv) ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) end - ccall((@blasfunc(dgetrf_), libblis), Cvoid, + ccall((@blasfunc(dgetrf_), liblapack), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) @@ -84,7 +89,7 @@ function getrf!(A::AbstractMatrix{<:Float32}; if isempty(ipiv) ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))) end - ccall((@blasfunc(sgetrf_), libblis), Cvoid, + ccall((@blasfunc(sgetrf_), liblapack), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) @@ -108,7 +113,7 @@ function getrs!(trans::AbstractChar, throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) end nrhs = size(B, 2) - ccall(("zgetrs_", libblis), Cvoid, + ccall(("zgetrs_", liblapack), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong), trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, @@ -133,7 +138,7 @@ function getrs!(trans::AbstractChar, throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) end nrhs = size(B, 2) - ccall(("cgetrs_", libblis), Cvoid, + ccall(("cgetrs_", liblapack), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, @@ -158,7 +163,7 @@ function getrs!(trans::AbstractChar, throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) end nrhs = size(B, 2) - ccall(("dgetrs_", libblis), Cvoid, + ccall(("dgetrs_", liblapack), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong), trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info, @@ -183,7 +188,7 @@ function getrs!(trans::AbstractChar, throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) end nrhs = size(B, 2) - ccall(("sgetrs_", libblis), Cvoid, + ccall(("sgetrs_", liblapack), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,