diff --git a/Project.toml b/Project.toml index 48ec6c8..de1c132 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.17" +version = "0.7.18" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/factorizations/eig.jl b/src/factorizations/eig.jl index 82358f4..502ac10 100644 --- a/src/factorizations/eig.jl +++ b/src/factorizations/eig.jl @@ -18,9 +18,10 @@ using MatrixAlgebraKit: for f in [:default_eig_algorithm, :default_eigh_algorithm] @eval begin - function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...) - alg = $f(blocktype(arrayt); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) + function MatrixAlgebraKit.$f(::Type{<:AbstractBlockSparseMatrix}; kwargs...) + return BlockPermutedDiagonalAlgorithm() do block + return $f(block; kwargs...) + end end end end @@ -45,12 +46,23 @@ function MatrixAlgebraKit.check_input( return nothing end +function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T} + DV = Base.promote_op(f, A) + !isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}} + return DV +end +function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T} + DV = Base.promote_op(f, A) + !isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}} + return DV +end + for f in [:eig_full!, :eigh_full!] @eval begin function MatrixAlgebraKit.initialize_output( ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm ) - Td, Tv = fieldtypes(Base.promote_op($f, blocktype(A), typeof(alg.alg))) + Td, Tv = fieldtypes(output_type($f, blocktype(A))) D = similar(A, BlockType(Td)) V = similar(A, BlockType(Tv)) return (D, V) @@ -60,7 +72,9 @@ for f in [:eig_full!, :eigh_full!] ) check_input($f, A, (D, V)) for I in eachstoredblockdiagindex(A) - D[I], V[I] = $f(@view(A[I]), alg.alg) + block = @view!(A[I]) + block_alg = block_algorithm(alg, block) + D[I], V[I] = $f(block, block_alg) end for I in eachunstoredblockdiagindex(A) # TODO: Support setting `LinearAlgebra.I` directly, and/or @@ -72,19 +86,31 @@ for f in [:eig_full!, :eigh_full!] end end +function output_type(f::typeof(eig_vals!), A::Type{<:AbstractMatrix{T}}) where {T} + D = Base.promote_op(f, A) + !isconcretetype(D) && return AbstractVector{complex(T)} + return D +end +function output_type(f::typeof(eigh_vals!), A::Type{<:AbstractMatrix{T}}) where {T} + D = Base.promote_op(f, A) + !isconcretetype(D) && return AbstractVector{real(T)} + return D +end + for f in [:eig_vals!, :eigh_vals!] @eval begin function MatrixAlgebraKit.initialize_output( ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm ) - T = Base.promote_op($f, blocktype(A), typeof(alg.alg)) + T = output_type($f, blocktype(A)) return similar(A, BlockType(T), axes(A, 1)) end function MatrixAlgebraKit.$f( A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm ) for I in eachblockstoredindex(A) - D[I] = $f(@view!(A[I]), alg.alg) + block = @view!(A[I]) + D[I] = $f(block, block_algorithm(alg, block)) end return D end diff --git a/src/factorizations/lq.jl b/src/factorizations/lq.jl index 79a068c..5dc1301 100644 --- a/src/factorizations/lq.jl +++ b/src/factorizations/lq.jl @@ -3,8 +3,9 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_ function MatrixAlgebraKit.default_lq_algorithm( A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - alg = default_lq_algorithm(blocktype(A); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) + return BlockPermutedDiagonalAlgorithm() do block + return default_lq_algorithm(block; kwargs...) + end end function similar_output( @@ -58,8 +59,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output( - lq_compact!, @view!(A[bI]), alg.alg + lq_compact!, block, block_alg ) end @@ -105,8 +108,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output( - lq_full!, @view!(A[bI]), alg.alg + lq_full!, block, block_alg ) end @@ -154,7 +159,9 @@ function MatrixAlgebraKit.lq_compact!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol])) - lq′ = lq_compact!(@view!(A[bI]), lq, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + lq′ = lq_compact!(block, lq, block_alg) @assert lq === lq′ "lq_compact! might not be in-place" end @@ -183,7 +190,9 @@ function MatrixAlgebraKit.lq_full!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol])) - lq′ = lq_full!(@view!(A[bI]), lq, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + lq′ = lq_full!(block, lq, block_alg) @assert lq === lq′ "lq_full! might not be in-place" end diff --git a/src/factorizations/qr.jl b/src/factorizations/qr.jl index ceaaf57..10c40d8 100644 --- a/src/factorizations/qr.jl +++ b/src/factorizations/qr.jl @@ -2,10 +2,11 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full! function MatrixAlgebraKit.default_qr_algorithm( - A::Type{<:AbstractBlockSparseMatrix}; kwargs... + ::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - alg = default_qr_algorithm(blocktype(A); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) + return BlockPermutedDiagonalAlgorithm() do block + return default_qr_algorithm(block; kwargs...) + end end function similar_output( @@ -59,8 +60,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output( - qr_compact!, @view!(A[bI]), alg.alg + qr_compact!, block, block_alg ) end @@ -106,8 +109,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output( - qr_full!, @view!(A[bI]), alg.alg + qr_full!, block, block_alg ) end @@ -155,7 +160,9 @@ function MatrixAlgebraKit.qr_compact!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol])) - qr′ = qr_compact!(@view!(A[bI]), qr, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + qr′ = qr_compact!(block, qr, block_alg) @assert qr === qr′ "qr_compact! might not be in-place" end @@ -184,7 +191,9 @@ function MatrixAlgebraKit.qr_full!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol])) - qr′ = qr_full!(@view!(A[bI]), qr, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + qr′ = qr_full!(block, qr, block_alg) @assert qr === qr′ "qr_full! might not be in-place" end diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 19e5c6c..8c616d4 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -10,24 +10,26 @@ A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped a a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix. """ -struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <: - MatrixAlgebraKit.AbstractAlgorithm - alg::A +struct BlockPermutedDiagonalAlgorithm{F} <: MatrixAlgebraKit.AbstractAlgorithm + falg::F +end +function block_algorithm(alg::BlockPermutedDiagonalAlgorithm, a::AbstractMatrix) + return block_algorithm(alg, typeof(a)) +end +function block_algorithm(alg::BlockPermutedDiagonalAlgorithm, A::Type{<:AbstractMatrix}) + return alg.falg(A) end function MatrixAlgebraKit.default_svd_algorithm( - A::Type{<:AbstractBlockSparseMatrix}; kwargs... + ::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - alg = default_svd_algorithm(blocktype(A); kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) + return BlockPermutedDiagonalAlgorithm() do block + return default_svd_algorithm(block; kwargs...) + end end -function output_type( - ::typeof(svd_compact!), - A::Type{<:AbstractMatrix{T}}, - Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm}, -) where {T} - USVᴴ = Base.promote_op(svd_compact!, A, Alg) +function output_type(::typeof(svd_compact!), A::Type{<:AbstractMatrix{T}}) where {T} + USVᴴ = Base.promote_op(svd_compact!, A) !isconcretetype(USVᴴ) && return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}} return USVᴴ @@ -36,7 +38,7 @@ end function similar_output( ::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm ) - BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg))) + BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A))) U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1])) S = similar(A, BlockType(BS), S_axes) Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2))) @@ -81,8 +83,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output( - svd_compact!, @view!(A[bI]), alg.alg + svd_compact!, block, block_alg ) end @@ -140,8 +144,10 @@ function MatrixAlgebraKit.initialize_output( # allocate output for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output( - svd_full!, @view!(A[bI]), alg.alg + svd_full!, block, block_alg ) end @@ -196,7 +202,9 @@ function MatrixAlgebraKit.svd_compact!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) - usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + usvᴴ′ = svd_compact!(block, usvᴴ, block_alg) @assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place" end @@ -226,7 +234,9 @@ function MatrixAlgebraKit.svd_full!( for bI in eachblockstoredindex(A) brow, bcol = Tuple(bI) usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol])) - usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg) + block = @view!(A[bI]) + block_alg = block_algorithm(alg, block) + usvᴴ′ = svd_full!(block, usvᴴ, block_alg) @assert usvᴴ === usvᴴ′ "svd_full! might not be in-place" end diff --git a/test/Project.toml b/test/Project.toml index cdca12e..6b6cd9a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ DiagonalArrays = "0.3" GPUArraysCore = "0.2" JLArrays = "0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.2" +MatrixAlgebraKit = "0.2.5" Random = "1" SafeTestsets = "0.1" SparseArraysBase = "0.5.11" diff --git a/test/test_abstract_blocktype.jl b/test/test_abstract_blocktype.jl index b54779b..5b2ca51 100644 --- a/test/test_abstract_blocktype.jl +++ b/test/test_abstract_blocktype.jl @@ -2,6 +2,26 @@ using Adapt: adapt using BlockArrays: Block using BlockSparseArrays: BlockSparseMatrix, blockstoredlength using JLArrays: JLArray +using LinearAlgebra: hermitianpart, norm +using MatrixAlgebraKit: + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + isisometry, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc using SparseArraysBase: storedlength using Test: @test, @test_broken, @testset @@ -11,30 +31,104 @@ arrayts = (Array, JLArray) elt in elts dev = adapt(arrayt) + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) @test sprint(show, MIME"text/plain"(), a) isa String @test iszero(storedlength(a)) @test iszero(blockstoredlength(a)) + + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) a[Block(1, 1)] = dev(randn(elt, 2, 2)) - a[Block(2, 2)] = dev(randn(elt, 3, 3)) @test !iszero(a[Block(1, 1)]) @test a[Block(1, 1)] isa arrayt{elt,2} - @test !iszero(a[Block(2, 2)]) - @test a[Block(2, 2)] isa arrayt{elt,2} + @test iszero(a[Block(2, 2)]) + @test a[Block(2, 2)] isa Matrix{elt} @test iszero(a[Block(2, 1)]) @test a[Block(2, 1)] isa Matrix{elt} @test iszero(a[Block(1, 2)]) @test a[Block(1, 2)] isa Matrix{elt} + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + a′ = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) + a′[Block(2, 2)] = dev(randn(elt, 3, 3)) + b = copy(a) @test Array(b) ≈ Array(a) - b = a + a - @test Array(b) ≈ Array(a) + Array(a) + b = a + a′ + @test Array(b) ≈ Array(a) + Array(a′) b = 3a @test Array(b) ≈ 3Array(a) b = a * a @test Array(b) ≈ Array(a) * Array(a) + + b = a * a′ + @test Array(b) ≈ Array(a) * Array(a′) + @test norm(b) ≈ 0 + + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + for f in (eig_full, eig_trunc) + if arrayt === Array + d, v = f(a) + @test a * v ≈ v * d + else + @test_broken f(a) + end + end + if arrayt === Array + d = eig_vals(a) + @test sort(Vector(d); by=abs) ≈ sort(eig_vals(Matrix(a)); by=abs) + else + @test_broken eig_vals(a) + end + + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = dev(parent(hermitianpart(randn(elt, 2, 2)))) + for f in (eigh_full, eigh_trunc) + if arrayt === Array + d, v = f(a) + @test a * v ≈ v * d + else + @test_broken f(a) + end + end + if arrayt === Array + d = eigh_vals(a) + @test sort(Vector(d); by=abs) ≈ sort(eig_vals(Matrix(a)); by=abs) + else + @test_broken eigh_vals(a) + end + + a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3]) + a[Block(1, 1)] = dev(randn(elt, 2, 2)) + for f in (left_orth, left_polar, qr_compact, qr_full) + if arrayt === Array + u, c = f(a) + @test u * c ≈ a + @test isisometry(u; side=:left) + else + @test_broken f(a) + end + end + for f in (right_orth, right_polar, lq_compact, lq_full) + if arrayt === Array + c, u = f(a) + @test c * u ≈ a + @test isisometry(u; side=:right) + else + @test_broken f(a) + end + end + for f in (svd_compact, svd_full, svd_trunc) + if arrayt === Array + u, s, v = f(a) + @test u * s * v ≈ a + else + @test_broken f(a) + end + end end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index f162c81..d9fb629 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -474,25 +474,3 @@ end @test sort(diagview(D[Block(1, 1)]); by=abs, rev=true) ≈ D1[1:1] @test sort(diagview(D[Block(2, 2)]); by=abs, rev=true) ≈ D2[1:2] end - -@testset "Abstract block type" begin - arrayt = Array - elt = Float32 - dev = adapt(arrayt) - - a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, ([2, 3], [2, 3])) - a[Block(1, 1)] = dev(randn(elt, 2, 2)) - a[Block(2, 2)] = dev(randn(elt, 3, 3)) - @test_broken eig_full(a) - @test_broken eigh_full(a) - @test_broken svd_compact(a) - @test_broken svd_full(a) - @test_broken left_orth(a) - @test_broken right_orth(a) - @test_broken left_polar(a) - @test_broken right_polar(a) - @test_broken qr_compact(a) - @test_broken qr_full(a) - @test_broken lq_compact(a) - @test_broken lq_full(a) -end