Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add evolve! for evolution of an MPS with an MPO #264

Merged
merged 18 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,21 +325,21 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
return tn
end

function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false)
function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; kwargs...)
# move orthogonality center to bond
mixed_canonize!(tn, bond)

return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize)
return truncate!(NonCanonical(), tn, bond; compute_local_svd=true, kwargs...)
end

"""
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true)
truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...)

Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize)
function truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...)
truncate!(NonCanonical(), tn, bond; compute_local_svd=false, kwargs...)

canonize && canonize!(tn)

Expand Down
147 changes: 138 additions & 9 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ end

Check if the tensors in the mps are in the proper [`Form`](@ref).
"""
check_form(mps::AbstractMPO) = check_form(form(mps), mps)
check_form(mps::AbstractMPO; kwargs...) = check_form(form(mps), mps; kwargs...)

function check_form(config::MixedCanonical, mps::AbstractMPO)
function check_form(config::MixedCanonical, mps::AbstractMPO; atol=1e-12)
orthog_center = config.orthog_center

left, right = if orthog_center isa Site
Expand All @@ -144,23 +144,24 @@ function check_form(config::MixedCanonical, mps::AbstractMPO)

for i in 1:nsites(mps)
if i < left # Check left-canonical tensors
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
isisometry(mps, Site(i); dir=:right, atol) || throw(ArgumentError("Tensors are not left-canonical"))
elseif i > right # Check right-canonical tensors
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
isisometry(mps, Site(i); dir=:left, atol) || throw(ArgumentError("Tensors are not right-canonical"))
end
end

return true
end

function check_form(::Canonical, mps::AbstractMPO)
function check_form(::Canonical, mps::AbstractMPO; atol=1e-12)
for i in 1:nsites(mps)
if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right)
if i > 1 &&
!isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right, atol)
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

if i < nsites(mps) &&
!isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left)
!isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left, atol)
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
end
Expand Down Expand Up @@ -541,6 +542,133 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center)
return tn
end

"""
evolve!(ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true)

Evolve the [`AbstractAnsatz`](@ref) `ψ` with the [`AbstractMPO`](@ref) `mpo` along the output indices of `ψ`.
If `threshold` or `maxdim` are not `nothing`, the tensors are truncated after each sweep at the proper value, and the
bond is normalized if `normalize=true`. If `reset_index=true`, the indices of the `ψ` are reset to the original ones.
"""
function evolve!(
ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true
)
original_sites = copy(Quantum(ψ).sites)
evolve!(form(ψ), ψ, mpo; threshold, maxdim, normalize)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@starsfordummies Tomorrow I will add here the replace! so this evolve! function does not change the output indices, after that, we can merge!


if reset_index
resetindex!(ψ; init=ninds(TensorNetwork(ψ)) + 1)

replacements = [inds(ψ; at=site) => original_sites[site] for site in keys(original_sites)]
replace!(ψ, replacements)
end

return ψ
end

function evolve!(::NonCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...)
L = nsites(ψ)
Tenet.@reindex! outputs(ψ) => inputs(mpo)

right_inds = [inds(ψ; at=Site(i), dir=:right) for i in 1:(L - 1)]

for i in 1:L
contract_ind = inds(ψ; at=Site(i))
push!(ψ, tensors(mpo; at=Site(i)))
contract!(ψ, contract_ind)
merge!(Quantum(ψ).sites, Dict(Site(i) => inds(mpo; at=Site(i))))
end

# Group the parallel bond indices
for i in 1:(L - 1)
groupinds!(ψ, right_inds[i])
end

if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(form(ψ), ψ; threshold, maxdim, normalize)
else
normalize && normalize!(ψ)
end

return ψ
end

function evolve!(::MixedCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; normalize, kwargs...)
initial_form = form(ψ)
mixed_canonize!(ψ, Site(nsites(ψ))) # We convert all the tensors to left-canonical form

evolve!(NonCanonical(), ψ, mpo; normalize, kwargs...)

mixed_canonize!(ψ, initial_form.orthog_center)

return ψ
end

function evolve!(::Canonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...)
# We first join the λs to the Γs to get MixedCanonical(Site(1)) form
for i in 1:(nsites(ψ) - 1)
contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right)
end

evolve!(NonCanonical(), ψ, mpo; threshold=nothing, maxdim=nothing, normalize=false, kwargs...) # set maxdim and threshold to nothing so we truncate from Canonical form

if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize)
else
normalize && canonize!(ψ; normalize)
end
Comment on lines +614 to +618
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need to normalize if you don't truncate?
also, for this case (i.e. Canonical form) you will canonize 2 times if you truncate and normalize (1 in truncate_sweep for recanonization and another here for normalization)

Suggested change
if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize)
else
normalize && canonize!(ψ; normalize)
end
if !isnothing(threshold) || !isnothing(maxdim)
truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize)
normalize && canonize!(ψ; normalize)
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this. This is not the case, you don't need to normalize if you already normalize the bond that you truncate.

And yes, you might want to normalize if you don't truncate if the mpo is not unitary and evolves the ψ out of norm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would you want to normalize here a TN that it's not already normalized? but ok, I see that both things could be understandable from a semantics point of view

let's merge, discuss on Monday and refactor it if we need


return ψ
end

"""
truncate_sweep!

Do a right-to-left QR sweep on the [`AbstractMPO`](@ref) `ψ` and then left-to-right SVD sweep and truncate the tensors
according to the `threshold` or `maxdim` values. The bond is normalized if `normalize=true`.
"""
function truncate_sweep! end

function truncate_sweep!(::NonCanonical, ψ::AbstractMPO; threshold, maxdim, normalize)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
for i in nsites(ψ):-1:2
canonize_site!(ψ, Site(i); direction=:left, method=:qr)
end

# left-to-right SVD sweep, get left-canonical tensors and singular values and truncate
for i in 1:(nsites(ψ) - 1)
canonize_site!(ψ, Site(i); direction=:right, method=:svd)

(!isnothing(threshold) || !isnothing(maxdim)) &&
truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false)

contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right)
end

ψ.form = MixedCanonical(Site(nsites(ψ)))

return ψ
end

function truncate_sweep!(::MixedCanonical, ψ::AbstractMPO; threshold, maxdim, normalize)
truncate_sweep!(NonCanonical(), ψ; threshold, maxdim, normalize)
end

function truncate_sweep!(::Canonical, ψ::AbstractMPO; threshold, maxdim, normalize)
for i in nsites(ψ):-1:2
canonize_site!(ψ, Site(i); direction=:left, method=:qr)
end

# left-to-right SVD sweep, get left-canonical tensors and singular values and truncate
for i in 1:(nsites(ψ) - 1)
canonize_site!(ψ, Site(i); direction=:right, method=:svd)
(!isnothing(threshold) || !isnothing(maxdim)) &&
truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false)
end

canonize!(ψ)

return ψ
end

LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...)
LinearAlgebra.normalize!(ψ::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at)
LinearAlgebra.normalize!(ψ::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond)
Expand All @@ -564,14 +692,15 @@ function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=co
end

function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing)
old_norm = norm(ψ)
if isnothing(bond) # Normalize all λ tensors
for i in 1:(nsites(ψ) - 1)
λ = tensors(ψ; between=(Site(i), Site(i + 1)))
replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1)))
replace!(ψ, λ => λ ./ old_norm^(1 / (nsites(ψ) - 1)))
end
else
λ = tensors(ψ; between=bond)
replace!(ψ, λ => λ ./ norm(λ))
replace!(ψ, λ => λ ./ old_norm)
Comment on lines +695 to +703
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so in reality, there's no need for this because norm(λ) will only be computed once as it already is

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nop, you need the norm of the full mps.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already tried that but fails on some truncate tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh i see the difference now, I thought it was the same norm call

end

return ψ
Expand Down
55 changes: 55 additions & 0 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,61 @@ using LinearAlgebra
@test_throws ArgumentError Tenet.check_form(evolved)
end
end

@testset "MPO evolution" begin
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])
normalize!(ψ)
mpo = rand(MPO; n=5, maxdim=8)

ϕ_1 = deepcopy(ψ)
ϕ_2 = deepcopy(ψ)
ϕ_3 = deepcopy(ψ)

@testset "NonCanonical" begin
evolve!(ϕ_1, mpo)
@test length(tensors(ϕ_1)) == 5
@test norm(ϕ_1) ≈ 1.0

evolved = evolve!(deepcopy(ψ), mpo; maxdim=3)
@test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test norm(evolved) ≈ 1.0
end

@testset "Canonical" begin
canonize!(ϕ_2)
evolve!(ϕ_2, mpo)
@test length(tensors(ϕ_2)) == 5 + 4
@test form(ϕ_2) == Canonical()
@test Tenet.check_form(ϕ_2)

evolved = evolve!(deepcopy(canonize!(ψ)), mpo; maxdim=3)
@test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test form(evolved) == Canonical()
@test Tenet.check_form(evolved)
end

@testset "MixedCanonical" begin
mixed_canonize!(ϕ_3, site"3")
evolve!(ϕ_3, mpo)
@test length(tensors(ϕ_3)) == 5
@test form(ϕ_3) == MixedCanonical(Site(3))
@test norm(ϕ_3) ≈ 1.0
@test Tenet.check_form(ϕ_3)

evolved = evolve!(deepcopy(mixed_canonize!(ψ, site"3")), mpo; maxdim=3)
@test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...))
@test form(evolved) == MixedCanonical(Site(3))
@test norm(evolved) ≈ 1.0
@test Tenet.check_form(evolved)
end

t1 = contract(ϕ_1)
t2 = contract(ϕ_2)
t3 = contract(ϕ_3)

@test t1 ≈ t2 ≈ t3
@test only(overlap(ϕ_1, ϕ_2)) ≈ only(overlap(ϕ_1, ϕ_3)) ≈ only(overlap(ϕ_2, ϕ_3)) ≈ 1.0
end
end

# TODO rename when method is renamed
Expand Down
Loading