From cf84d9e00d84d34ec11ca7159f8d2b0f1297a63d Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 14:15:57 -0300 Subject: [PATCH 1/5] Make `lu` results have same storage mode as input --- lib/mps/linalg.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 1c29e75c0..8165cdcf5 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -118,7 +118,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = dev = device() queue = global_queue(dev) - At = MtlMatrix{T,PrivateStorage}(undef, (N, M)) + At = similar(A, (N, M)) mps_a = MPSMatrix(A) mps_at = MPSMatrix(At) @@ -137,7 +137,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = encode!(cbuf, kernel, mps_at, mps_at, mps_p, status) end - B = MtlMatrix{T}(undef, M, N) + B = similar(A, M, N) commit!(cmdbuf) do cbuf mps_b = MPSMatrix(B) @@ -176,7 +176,7 @@ end dev = device() queue = global_queue(dev) - At = MtlMatrix{T,PrivateStorage}(undef, (N, M)) + At = similar(A, (N, M)) mps_a = MPSMatrix(A) mps_at = MPSMatrix(At) From 36bebf7e951c679502ef7f791d63860482da1a76 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 14:41:19 -0300 Subject: [PATCH 2/5] Add tests --- test/mps/linalg.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index d0f982489..b840fe0d2 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -191,22 +191,29 @@ end end @testset "decompositions" begin + testreturntype(_,_) = false + testreturntype(luobj::LU{<:Any,<:MtlArray{<:Any,<:Any,S},<:MtlArray{<:Any,<:Any,S}},::MtlArray{<:Any,<:Any,S}) where S = true + A = MtlMatrix(rand(Float32, 1024, 1024)) lua = lu(A) + @test testreturntype(lua,A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 1024, 512)) lua = lu(A) + @test testreturntype(lua,A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 512, 1024)) lua = lu(A) + @test testreturntype(lua,A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A a = rand(Float32, 1024, 1024) A = MtlMatrix(a) B = MtlMatrix(a) lua = lu!(A) + @test testreturntype(lua,A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * B A = MtlMatrix{Float32}([1 2; 0 0]) From 86dfc2a33879e24e207ce92c508657b62165f175 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:06:52 -0300 Subject: [PATCH 3/5] Better tests --- test/mps/linalg.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index b840fe0d2..dc99f734c 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -190,30 +190,28 @@ end end end +using Metal: storagemode @testset "decompositions" begin - testreturntype(_,_) = false - testreturntype(luobj::LU{<:Any,<:MtlArray{<:Any,<:Any,S},<:MtlArray{<:Any,<:Any,S}},::MtlArray{<:Any,<:Any,S}) where S = true - A = MtlMatrix(rand(Float32, 1024, 1024)) lua = lu(A) - @test testreturntype(lua,A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 1024, 512)) lua = lu(A) - @test testreturntype(lua,A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 512, 1024)) lua = lu(A) - @test testreturntype(lua,A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A a = rand(Float32, 1024, 1024) A = MtlMatrix(a) B = MtlMatrix(a) lua = lu!(A) - @test testreturntype(lua,A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * B A = MtlMatrix{Float32}([1 2; 0 0]) From 7a9ed0e2125882238c4d31a16a315843d0cd6818 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:23:43 -0300 Subject: [PATCH 4/5] Actually useful tests --- test/mps/linalg.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index dc99f734c..106d7669c 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -194,28 +194,31 @@ using Metal: storagemode @testset "decompositions" begin A = MtlMatrix(rand(Float32, 1024, 1024)) lua = lu(A) - @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 1024, 512)) lua = lu(A) - @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A A = MtlMatrix(rand(Float32, 512, 1024)) lua = lu(A) - @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * A a = rand(Float32, 1024, 1024) A = MtlMatrix(a) B = MtlMatrix(a) lua = lu!(A) - @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) @test lua.L * lua.U ≈ MtlMatrix(lua.P) * B A = MtlMatrix{Float32}([1 2; 0 0]) @test_throws SingularException lu(A) + + altStorage = Metal.DefaultStorageMode != Metal.PrivateStorage ? Metal.PrivateStorage : Metal.SharedStorage + A = MtlMatrix{Float32,altStorage}(rand(Float32, 1024, 1024)) + lua = lu(A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) + lua = lu!(A) + @test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A) end using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax From 654953f794caebdcf53cbafa945c0e1f8bd877a5 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 17:30:53 -0300 Subject: [PATCH 5/5] Fix code --- lib/mps/linalg.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 8165cdcf5..8558f756a 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -118,7 +118,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = dev = device() queue = global_queue(dev) - At = similar(A, (N, M)) + At = MtlMatrix{T,PrivateStorage}(undef, (N, M)) mps_a = MPSMatrix(A) mps_at = MPSMatrix(At) @@ -128,7 +128,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = encode!(cbuf, kernel, descriptor) end - P = MtlMatrix{UInt32}(undef, 1, min(N, M)) + P = similar(A, UInt32, 1, min(N, M)) status = MtlArray{MPSMatrixDecompositionStatus}(undef) commitAndContinue!(cmdbuf) do cbuf @@ -176,7 +176,7 @@ end dev = device() queue = global_queue(dev) - At = similar(A, (N, M)) + At = MtlMatrix{T,PrivateStorage}(undef, (N, M)) mps_a = MPSMatrix(A) mps_at = MPSMatrix(At) @@ -186,7 +186,7 @@ end encode!(cbuf, kernel, descriptor) end - P = MtlMatrix{UInt32}(undef, 1, min(N, M)) + P = similar(A, UInt32, 1, min(N, M)) status = MtlArray{MPSMatrixDecompositionStatus}(undef) commitAndContinue!(cmdbuf) do cbuf