From 4b8c356a535afd404e72dc6e2b83b33c362a0a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:48:54 +0100 Subject: [PATCH] Fix `truncate!` function (#254) * Fix truncate function and add small test * Fix mixed_canonize! function * Add form tests * Format code * Fix code * Fix orthog_center field in MixedCanonical form * Enhance mixed_canonize! tests * Add recanonize kwarg for truncate(::Canonical, ...) function * Small fixes on check_form functions * Small fixes on tests * Format code * Add comment --- src/Ansatz.jl | 32 +++++++++++++++------- src/MPS.jl | 36 ++++++++++++++++++------ test/MPS_test.jl | 71 +++++++++++++++++++++++++++++++++++++----------- 3 files changed, 104 insertions(+), 35 deletions(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index c7a96904d..ea28bd893 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -49,7 +49,7 @@ struct NonCanonical <: Form end left of the orthogonality center are left-canonical and the tensors to the right are right-canonical. """ struct MixedCanonical <: Form - orthog_center::Union{Site,Vector{Site}} + orthog_center::Union{Site,Vector{<:Site}} end """ @@ -255,8 +255,8 @@ Truncate the dimension of the virtual `bond`` of an [`Ansatz`](@ref) Tensor Netw - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. """ -function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) - return truncate!(form(tn), tn, bond; threshold, maxdim) +function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing, kwargs...) + return truncate!(form(tn), tn, bond; threshold, maxdim, kwargs...) end """ @@ -290,14 +290,18 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, spectrum = parent(tensors(tn; bond)) - maxdim = isnothing(maxdim) ? size(tn, virtualind) : maxdim + maxdim = isnothing(maxdim) ? size(tn, virtualind) : min(maxdim, length(spectrum)) extent = if isnothing(threshold) 1:maxdim else - 1:something(findfirst(1:maxdim) do i + # Find the first index where the condition is met + found_index = findfirst(1:maxdim) do i abs(spectrum[i]) < threshold - end - 1, maxdim) + end + + # If no index is found, return 1:length(spectrum), otherwise calculate the range + 1:(isnothing(found_index) ? maxdim : found_index - 1) end slice!(tn, virtualind, extent) @@ -308,13 +312,21 @@ end function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim) # move orthogonality center to bond mixed_canonize!(tn, bond) - return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) + return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true) end -function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim) +""" + truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true) + +Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest +**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`. +""" +function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true) truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) - # requires a sweep to recanonize the TN - return canonize!(tn) + + recanonize && canonize!(tn) + + return tn end overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) diff --git a/src/MPS.jl b/src/MPS.jl index 4ad94d9fa..e732c1af8 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -126,14 +126,26 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true) return mps end +""" + check_form(mps::AbstractMPO) + +Check if the tensors in the mps are in the proper [`Form`](@ref). +""" check_form(mps::AbstractMPO) = check_form(form(mps), mps) function check_form(config::MixedCanonical, mps::AbstractMPO) orthog_center = config.orthog_center + + left, right = if orthog_center isa Site + id(orthog_center) .+ (0, 0) # So left and right get the same value + elseif orthog_center isa Vector{<:Site} + extrema(id.(orthog_center)) + end + for i in 1:nsites(mps) - if i < id(orthog_center) # Check left-canonical tensors + if i < left # Check left-canonical tensors isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) - elseif i > id(orthog_center) # Check right-canonical tensors + elseif i > right # Check right-canonical tensors isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) end end @@ -143,8 +155,7 @@ end function check_form(::Canonical, mps::AbstractMPO) for i in 1:nsites(mps) - if i > 1 - !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) + if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) end @@ -157,6 +168,8 @@ function check_form(::Canonical, mps::AbstractMPO) return true end +check_form(::NonCanonical, mps::AbstractMPO) = true + """ MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) @@ -504,19 +517,24 @@ end # TODO dispatch on form # TODO generalize to AbstractAnsatz function mixed_canonize!(tn::AbstractMPO, orthog_center) + left, right = if orthog_center isa Site + id(orthog_center) .+ (-1, 1) + elseif orthog_center isa Vector{<:Site} + extrema(id.(orthog_center)) .+ (-1, 1) + else + throw(ArgumentError("`orthog_center` must be a `Site` or a `Vector{Site}`")) + end + # left-to-right QR sweep (left-canonical tensors) - for i in 1:(id(orthog_center) - 1) + for i in 1:left canonize_site!(tn, Site(i); direction=:right, method=:qr) end # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:(id(orthog_center) + 1) + for i in nsites(tn):-1:right canonize_site!(tn, Site(i); direction=:left, method=:qr) end - # center SVD sweep to get singular values - # canonize_site!(tn, orthog_center; direction=:left, method=:svd) - tn.form = MixedCanonical(orthog_center) return tn diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 86db308c7..6aa921f2d 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -95,15 +95,36 @@ using LinearAlgebra end @testset "truncate!" begin - ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - canonize_site!(ψ, Site(2); direction=:right, method=:svd) + @testset "NonCanonical" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(ψ, Site(2); direction=:right, method=:svd) + + truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + singular_values = tensors(ψ; between=(site"2", site"3")) + truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + # If maxdim > size(spectrum), the bond dimension is not truncated + truncated = truncate(ψ, [site"2", site"3"]; maxdim=4) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2 + end + + @testset "Canonical" begin + ψ = rand(MPS; n=5, maxdim=16) + canonize!(ψ) + + truncated = truncate(ψ, [site"2", site"3"]; maxdim=2) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2 + end - truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) - @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + @testset "MixedCanonical" begin + ψ = rand(MPS; n=5, maxdim=16) - singular_values = tensors(ψ; between=(site"2", site"3")) - truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) - @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + truncated = truncate(ψ, [site"2", site"3"]; maxdim=3) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 3 + end end @testset "norm" begin @@ -206,18 +227,36 @@ using LinearAlgebra end @testset "mixed_canonize!" begin - ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(ψ, site"3") + @testset "single Site" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, site"3") + @test Tenet.check_form(canonized) + + @test form(canonized) isa MixedCanonical + @test form(canonized).orthog_center == site"3" + + @test isisometry(canonized, site"1"; dir=:right) + @test isisometry(canonized, site"2"; dir=:right) + @test isisometry(canonized, site"4"; dir=:left) + @test isisometry(canonized, site"5"; dir=:left) - @test form(canonized) isa MixedCanonical - @test form(canonized).orthog_center == site"3" + @test contract(canonized) ≈ contract(ψ) + end + + @testset "multiple Sites" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, [site"2", site"3"]) - @test isisometry(canonized, site"1"; dir=:right) - @test isisometry(canonized, site"2"; dir=:right) - @test isisometry(canonized, site"4"; dir=:left) - @test isisometry(canonized, site"5"; dir=:left) + @test Tenet.check_form(canonized) + @test form(canonized) isa MixedCanonical + @test form(canonized).orthog_center == [site"2", site"3"] - @test contract(canonized) ≈ contract(ψ) + @test isisometry(canonized, site"1"; dir=:right) + @test isisometry(canonized, site"4"; dir=:left) + @test isisometry(canonized, site"5"; dir=:left) + + @test contract(canonized) ≈ contract(ψ) + end end @testset "expect" begin