diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 76e4dfb61..e4a9cf434 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -325,21 +325,21 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, return tn end -function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false) +function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; kwargs...) # move orthogonality center to bond mixed_canonize!(tn, bond) - return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize) + return truncate!(NonCanonical(), tn, bond; compute_local_svd=true, kwargs...) end """ - truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true) + truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...) Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest **Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`. """ -function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false) - truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize) +function truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...) + truncate!(NonCanonical(), tn, bond; compute_local_svd=false, kwargs...) canonize && canonize!(tn) diff --git a/src/MPS.jl b/src/MPS.jl index f96b2f826..c5dfd8ac8 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -131,9 +131,9 @@ end Check if the tensors in the mps are in the proper [`Form`](@ref). """ -check_form(mps::AbstractMPO) = check_form(form(mps), mps) +check_form(mps::AbstractMPO; kwargs...) = check_form(form(mps), mps; kwargs...) -function check_form(config::MixedCanonical, mps::AbstractMPO) +function check_form(config::MixedCanonical, mps::AbstractMPO; atol=1e-12) orthog_center = config.orthog_center left, right = if orthog_center isa Site @@ -144,23 +144,24 @@ function check_form(config::MixedCanonical, mps::AbstractMPO) for i in 1:nsites(mps) if i < left # Check left-canonical tensors - isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) + isisometry(mps, Site(i); dir=:right, atol) || throw(ArgumentError("Tensors are not left-canonical")) elseif i > right # Check right-canonical tensors - isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) + isisometry(mps, Site(i); dir=:left, atol) || throw(ArgumentError("Tensors are not right-canonical")) end end return true end -function check_form(::Canonical, mps::AbstractMPO) +function check_form(::Canonical, mps::AbstractMPO; atol=1e-12) for i in 1:nsites(mps) - if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) + if i > 1 && + !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right, atol) throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) end if i < nsites(mps) && - !isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) + !isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left, atol) throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) end end @@ -541,6 +542,133 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center) return tn end +""" + evolve!(ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true) + +Evolve the [`AbstractAnsatz`](@ref) `ψ` with the [`AbstractMPO`](@ref) `mpo` along the output indices of `ψ`. +If `threshold` or `maxdim` are not `nothing`, the tensors are truncated after each sweep at the proper value, and the +bond is normalized if `normalize=true`. If `reset_index=true`, the indices of the `ψ` are reset to the original ones. +""" +function evolve!( + ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true +) + original_sites = copy(Quantum(ψ).sites) + evolve!(form(ψ), ψ, mpo; threshold, maxdim, normalize) + + if reset_index + resetindex!(ψ; init=ninds(TensorNetwork(ψ)) + 1) + + replacements = [inds(ψ; at=site) => original_sites[site] for site in keys(original_sites)] + replace!(ψ, replacements) + end + + return ψ +end + +function evolve!(::NonCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...) + L = nsites(ψ) + Tenet.@reindex! outputs(ψ) => inputs(mpo) + + right_inds = [inds(ψ; at=Site(i), dir=:right) for i in 1:(L - 1)] + + for i in 1:L + contract_ind = inds(ψ; at=Site(i)) + push!(ψ, tensors(mpo; at=Site(i))) + contract!(ψ, contract_ind) + merge!(Quantum(ψ).sites, Dict(Site(i) => inds(mpo; at=Site(i)))) + end + + # Group the parallel bond indices + for i in 1:(L - 1) + groupinds!(ψ, right_inds[i]) + end + + if !isnothing(threshold) || !isnothing(maxdim) + truncate_sweep!(form(ψ), ψ; threshold, maxdim, normalize) + else + normalize && normalize!(ψ) + end + + return ψ +end + +function evolve!(::MixedCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; normalize, kwargs...) + initial_form = form(ψ) + mixed_canonize!(ψ, Site(nsites(ψ))) # We convert all the tensors to left-canonical form + + evolve!(NonCanonical(), ψ, mpo; normalize, kwargs...) + + mixed_canonize!(ψ, initial_form.orthog_center) + + return ψ +end + +function evolve!(::Canonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...) + # We first join the λs to the Γs to get MixedCanonical(Site(1)) form + for i in 1:(nsites(ψ) - 1) + contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right) + end + + evolve!(NonCanonical(), ψ, mpo; threshold=nothing, maxdim=nothing, normalize=false, kwargs...) # set maxdim and threshold to nothing so we truncate from Canonical form + + if !isnothing(threshold) || !isnothing(maxdim) + truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize) + else + normalize && canonize!(ψ; normalize) + end + + return ψ +end + +""" + truncate_sweep! + +Do a right-to-left QR sweep on the [`AbstractMPO`](@ref) `ψ` and then left-to-right SVD sweep and truncate the tensors +according to the `threshold` or `maxdim` values. The bond is normalized if `normalize=true`. +""" +function truncate_sweep! end + +function truncate_sweep!(::NonCanonical, ψ::AbstractMPO; threshold, maxdim, normalize) + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values and truncate + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + + (!isnothing(threshold) || !isnothing(maxdim)) && + truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false) + + contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right) + end + + ψ.form = MixedCanonical(Site(nsites(ψ))) + + return ψ +end + +function truncate_sweep!(::MixedCanonical, ψ::AbstractMPO; threshold, maxdim, normalize) + truncate_sweep!(NonCanonical(), ψ; threshold, maxdim, normalize) +end + +function truncate_sweep!(::Canonical, ψ::AbstractMPO; threshold, maxdim, normalize) + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values and truncate + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + (!isnothing(threshold) || !isnothing(maxdim)) && + truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false) + end + + canonize!(ψ) + + return ψ +end + LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...) LinearAlgebra.normalize!(ψ::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at) LinearAlgebra.normalize!(ψ::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond) @@ -564,14 +692,15 @@ function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=co end function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing) + old_norm = norm(ψ) if isnothing(bond) # Normalize all λ tensors for i in 1:(nsites(ψ) - 1) λ = tensors(ψ; between=(Site(i), Site(i + 1))) - replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1))) + replace!(ψ, λ => λ ./ old_norm^(1 / (nsites(ψ) - 1))) end else λ = tensors(ψ; between=bond) - replace!(ψ, λ => λ ./ norm(λ)) + replace!(ψ, λ => λ ./ old_norm) end return ψ diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 25e67fc88..b8359a4e9 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -374,6 +374,61 @@ using LinearAlgebra @test_throws ArgumentError Tenet.check_form(evolved) end end + + @testset "MPO evolution" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + normalize!(ψ) + mpo = rand(MPO; n=5, maxdim=8) + + ϕ_1 = deepcopy(ψ) + ϕ_2 = deepcopy(ψ) + ϕ_3 = deepcopy(ψ) + + @testset "NonCanonical" begin + evolve!(ϕ_1, mpo) + @test length(tensors(ϕ_1)) == 5 + @test norm(ϕ_1) ≈ 1.0 + + evolved = evolve!(deepcopy(ψ), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test norm(evolved) ≈ 1.0 + end + + @testset "Canonical" begin + canonize!(ϕ_2) + evolve!(ϕ_2, mpo) + @test length(tensors(ϕ_2)) == 5 + 4 + @test form(ϕ_2) == Canonical() + @test Tenet.check_form(ϕ_2) + + evolved = evolve!(deepcopy(canonize!(ψ)), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test form(evolved) == Canonical() + @test Tenet.check_form(evolved) + end + + @testset "MixedCanonical" begin + mixed_canonize!(ϕ_3, site"3") + evolve!(ϕ_3, mpo) + @test length(tensors(ϕ_3)) == 5 + @test form(ϕ_3) == MixedCanonical(Site(3)) + @test norm(ϕ_3) ≈ 1.0 + @test Tenet.check_form(ϕ_3) + + evolved = evolve!(deepcopy(mixed_canonize!(ψ, site"3")), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test form(evolved) == MixedCanonical(Site(3)) + @test norm(evolved) ≈ 1.0 + @test Tenet.check_form(evolved) + end + + t1 = contract(ϕ_1) + t2 = contract(ϕ_2) + t3 = contract(ϕ_3) + + @test t1 ≈ t2 ≈ t3 + @test only(overlap(ϕ_1, ϕ_2)) ≈ only(overlap(ϕ_1, ϕ_3)) ≈ only(overlap(ϕ_2, ϕ_3)) ≈ 1.0 + end end # TODO rename when method is renamed