diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 8558f756..dd70f5f2 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -109,8 +109,10 @@ end # Metal's pivoting sequence needs to be iterated sequentially... # TODO: figure out a GPU-compatible way to get the permutation matrix -LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = +LinearAlgebra.ipiv2perm(v::MtlVector, maxi::Integer) = LinearAlgebra.ipiv2perm(Array(v), maxi) +LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) = + LinearAlgebra.ipiv2perm(unsafe_wrap(Array, v), maxi) @autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool=true) where {T<:MtlFloat} @@ -129,7 +131,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = end P = similar(A, UInt32, 1, min(N, M)) - status = MtlArray{MPSMatrixDecompositionStatus}(undef) + status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef) commitAndContinue!(cmdbuf) do cbuf mps_p = MPSMatrix(P) @@ -150,7 +152,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = wait_completed(cmdbuf) - status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[]) + status = convert(LinearAlgebra.BlasInt, status[]) check && checknonsingular(status) return LinearAlgebra.LU(B, p, status) @@ -187,7 +189,7 @@ end end P = similar(A, UInt32, 1, min(N, M)) - status = MtlArray{MPSMatrixDecompositionStatus}(undef) + status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef) commitAndContinue!(cmdbuf) do cbuf mps_p = MPSMatrix(P) @@ -205,7 +207,7 @@ end wait_completed(cmdbuf) - status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[]) + status = convert(LinearAlgebra.BlasInt, status[]) check && _check_lu_success(status, allowsingular) return LinearAlgebra.LU(A, p, status)