Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement MPS and MPO #232

Merged
merged 72 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
530069f
Prototype `MPS`, `MPO`
mofeing Aug 9, 2024
823da0e
Implement `rand`, `adjoint`, `defaultorder`, `boundary`, `form` for `…
mofeing Sep 12, 2024
789f101
Implement conversion from `Product` to `MPS`, `MPO`
mofeing Sep 12, 2024
bf96569
Refactor `MPS`, `MPO` on top of new `Ansatz` type
mofeing Sep 16, 2024
8115965
Move `Chain` code to `AbstractAnsatz` and `MPS`
mofeing Nov 7, 2024
71feb81
Fix `sites` method for `MPS`
mofeing Sep 18, 2024
f89773a
Fix `inds` method for `MPS`
mofeing Sep 18, 2024
82d8183
Refactor `adapt_structure` method to support additional types
mofeing Nov 7, 2024
eafc6de
Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on …
mofeing Nov 7, 2024
80e9455
Refactor `ChainRules` methods on top of new types
mofeing Nov 7, 2024
c46ac90
Refactor `rand` for `MPS`, `MPO`
mofeing Sep 18, 2024
7d1d254
Refactor `Chain` tests on top of `MPS`, `MPO`
mofeing Sep 18, 2024
830c8f4
Try using more `@site_str` instead of `Site` in MPS tests
mofeing Sep 18, 2024
1e82cc9
Implement some `sites`, `inds` methods for `MPO`
mofeing Sep 18, 2024
2dbe92e
Try using more `@site_str` in MPO tests
mofeing Sep 18, 2024
8df6213
Fix typo in `mixed_canonize!`
mofeing Sep 18, 2024
f030054
Fix `truncate` tests on `MPS`
mofeing Sep 18, 2024
775c6af
Refactor some tests of `MPS` to simplify
mofeing Sep 18, 2024
bf48d50
Fix typo in `normalize!` on `MPS` method
mofeing Sep 18, 2024
6cd0ef8
Fix typo
mofeing Sep 18, 2024
ec68b2e
Deprecate `isleftcanonical`, `isrightcanonical` in favor of `isisometry`
mofeing Sep 18, 2024
bdd2e26
Fix `isleftcanonical`, `isrightcanonical` tests on boundary sites
mofeing Sep 19, 2024
afddb6d
Fix `evolve!` calls in tests
mofeing Sep 19, 2024
b9a5148
Refactor MPO tests
mofeing Sep 22, 2024
5e04993
Stop orthogonalization to index on `mixed_canonize!`
mofeing Sep 26, 2024
6f84deb
Aesthetic name fix
mofeing Sep 26, 2024
e630527
Stop using `IdDict` on Reactant extension
mofeing Sep 28, 2024
5020d20
Fix `create_result` on `MPS`, `MPO`
mofeing Sep 30, 2024
cb1c8f4
Refactor lattice generation in constructors of `Dense`, `Product`, `M…
mofeing Sep 30, 2024
be708b8
Implement an MPS method initializing the tensors to identity (copy-te…
Todorbsc Oct 14, 2024
1b977fa
move files
mofeing Nov 7, 2024
a38ed16
fix constructors
mofeing Nov 7, 2024
7c12dab
Document types
mofeing Nov 7, 2024
c445787
Remove unimplemented `evolve!` method
mofeing Nov 7, 2024
6476930
fix mutability of `MPO`
mofeing Nov 8, 2024
a0e48f7
Move `MPO` code to "MPS.jl" and refactor common code
mofeing Nov 8, 2024
f97d33c
document `MPS`, `MPO` constructors
mofeing Nov 8, 2024
19264cf
document `rand` on `MPS`, `MPO`
mofeing Nov 8, 2024
1ec32bb
move some docstrings to `AbstractAnsatz`
mofeing Nov 8, 2024
b69d744
Refactor `normalize!`
mofeing Nov 8, 2024
c35ea9f
Fix `defaultorder`
mofeing Nov 8, 2024
109c9e5
fix `normalize!`
mofeing Nov 8, 2024
fc9c48a
apply `isisometry` docstring suggestion by @starsfordummies
mofeing Nov 8, 2024
84f422e
add shortcut for `normalize!` with mixed canonization
mofeing Nov 8, 2024
9b82c00
fix MPS identity constructor test
mofeing Nov 8, 2024
0a9de7b
implement shortcut `Quantum` constructor for simple gates
mofeing Nov 8, 2024
c932b7a
fix test
mofeing Nov 8, 2024
ae66855
refactor exported names
mofeing Nov 8, 2024
acf064e
fix `mixed_canonize!` tests
mofeing Nov 8, 2024
dcde266
fix `canonize!`, `mixed_canonize!`
mofeing Nov 8, 2024
3f5f7f8
import missing symbols to tests
mofeing Nov 8, 2024
a421a62
fix field name of `MixedCanonical`
mofeing Nov 10, 2024
ace97f1
fix namespace clash with `truncate`
mofeing Nov 10, 2024
2b31aeb
fix `truncate!`
mofeing Nov 10, 2024
f0b4d31
fix tests
mofeing Nov 10, 2024
5fae38c
try fix `mixed_canonize!`, `normalize!`
mofeing Nov 10, 2024
837ba8e
fix keyword args of `simple_update!` call
mofeing Nov 10, 2024
8186bc7
comment
mofeing Nov 10, 2024
16c5cbc
fix `MPO` test
mofeing Nov 10, 2024
a1a8b33
more fixes
mofeing Nov 10, 2024
293dae9
fix test
mofeing Nov 11, 2024
f8d4a91
refactor legacy `simple_update!` on `Canonical` form
mofeing Nov 11, 2024
de031b2
fix symbol in test
mofeing Nov 11, 2024
f38876d
rename testset
mofeing Nov 11, 2024
d97d154
fix `reindex!`
mofeing Nov 11, 2024
3b1bcc4
Remove legacy `@show`
mofeing Nov 11, 2024
ddbf909
format code
mofeing Nov 11, 2024
71defdb
refactor `evolve!` tests
mofeing Nov 11, 2024
fff1cb8
refactor `evolve!` tests again
mofeing Nov 11, 2024
c725495
fix indexing in `simple_update!`
mofeing Nov 11, 2024
c248c52
fix wrong call to `canonize!`
mofeing Nov 11, 2024
c9d3f52
try fix forward-mode diff of `MPS`, `MPO` constructors
mofeing Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ext/TenetAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tens
Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x))
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x))

end
2 changes: 2 additions & 0 deletions ext/TenetChainRulesCoreExt/frules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ end

# `AbstractAnsatz`-subtype constructors
ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, form=NoTangent())
ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, form=NoTangent())

# `Base.conj` methods
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)
Expand Down
8 changes: 8 additions & 0 deletions ext/TenetChainRulesCoreExt/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ Product_pullback(ȳ) = (NoTangent(), ȳ.tn)
Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback

MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback

MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback

# `Base.conj` methods
conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ))
conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ))
Expand Down
5 changes: 5 additions & 0 deletions ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz)
end

ChainRulesTestUtils.rand_tangent(::AbstractRNG, lattice::Tenet.Lattice) = NoTangent()
ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Lattice, msg=""; kwargs...) = true
ChainRulesTestUtils.test_approx(actual::Tenet.Lattice, expected::Tenet.Lattice, msg; kwargs...) = actual == expected

ChainRulesTestUtils.rand_tangent(::AbstractRNG, form::Tenet.Form) = NoTangent()
ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Form, msg=""; kwargs...) = true
ChainRulesTestUtils.test_approx(actual::Tenet.Form, expected::Tenet.Form, msg; kwargs...) = actual == expected

end
10 changes: 4 additions & 6 deletions ext/TenetQuacExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ module TenetQuacExt
using Tenet
using Quac: Gate, Circuit, lanes, arraytype, Swap

# function Tenet.Dense(gate::Gate)
# return Tenet.Dense(
# Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...]
# )
# end
function Tenet.Quantum(gate::Gate)
return Tenet.Quantum(arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...])
end

# Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...)
Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Quantum(gate); kwargs...)

function Tenet.Quantum(circuit::Circuit)
n = lanes(circuit)
Expand Down
14 changes: 14 additions & 0 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac
return Tenet.Product(tracetn)
end

for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn, form(prev))
end
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores)
return :($Tensor($data, $(inds(tocopy))))
Expand Down Expand Up @@ -70,6 +77,13 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu
return :($(Tenet.Product)($tn))
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end
end

# TODO try rely on generic fallback for ansatzes
# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
Expand Down
202 changes: 67 additions & 135 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct NonCanonical <: Form end
[`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form.
"""
struct MixedCanonical <: Form
orthogonality_center::Union{Site,Vector{Site}}
orthog_center::Union{Site,Vector{Site}}
end

"""
Expand Down Expand Up @@ -192,15 +192,52 @@ Contract the virtual bond between two [`Site`](@ref)s in a [`AbstractAnsatz`](@r
"""
@kwmethod contract!(tn::AbstractAnsatz; bond) = contract!(tn, inds(tn; bond))

"""
canonize!(tn::AbstractAnsatz)

Transform an [`AbstractAnsatz`](@ref) Tensor Network into the canonical form (aka Vidal gauge); i.e. the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ.
"""
function canonize! end

"""
canonize(tn::AbstractAnsatz)

Like [`canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
canonize(tn::AbstractAnsatz, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)

"""
mixed_canonize!(tn::AbstractAnsatz, orthog_center)

Transform an [`AbstractAnsatz`](@ref) Tensor Network into the mixed-canonical form, that is,
for `i < orthog_center` the tensors are left-canonical and for `i >= orthog_center` the tensors are right-canonical,
and in the `orthog_center` there is a tensor with the Schmidt coefficients in it.
"""
function mixed_canonize! end

"""
mixed_canonize(tn::AbstractAnsatz, orthog_center)

Like [`mixed_canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
mixed_canonize(tn::AbstractAnsatz, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...)

canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)

"""
isisometry(tn::AbstractAnsatz, site; dir, kwargs...)

Check if the tensor at a given [`Site`](@ref) in a [`AbstractAnsatz`](@ref) Tensor Network is an isometry.
mofeing marked this conversation as resolved.
Show resolved Hide resolved
The `dir` keyword argument specifies the direction of the isometry to check.
"""
function isisometry end

"""
truncate(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)

Like [`truncate!`](@ref), but returns a new tensor network instead of modifying the original one.
Like [`truncate!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)
Base.truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)

"""
truncate!(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)
Expand Down Expand Up @@ -236,15 +273,21 @@ Truncate the dimension of the virtual `bond` of a [`NonCanonical`](@ref) Tensor
- `compute_local_svd`: Whether to compute the local SVD of the bond. If `true`, it will contract the bond and perform a SVD to get the local singular values. Defaults to `true`.
"""
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true)
virtualind = inds(tn; bond)

if compute_local_svd
tₗ = tensors(tn; at=min(bond...))
tᵣ = tensors(tn; at=max(bond...))
contract!(tn; bond)
svd!(tn; virtualind=inds(tn; bond))

left_inds = filter(!=(virtualind), inds(tₗ))
right_inds = filter(!=(virtualind), inds(tᵣ))
svd!(tn; left_inds, right_inds, virtualind=virtualind)
end

spectrum = parent(tensors(tn; bond))
vind = inds(tn; bond)

maxdim = isnothing(maxdim) ? size(tn, vind) : maxdim
maxdim = isnothing(maxdim) ? size(tn, virtualind) : maxdim

extent = if isnothing(threshold)
1:maxdim
Expand All @@ -254,7 +297,7 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
end - 1, maxdim)
end

slice!(tn, vind, extent)
slice!(tn, virtualind, extent)

return tn
end
Expand All @@ -268,7 +311,7 @@ end
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
# requires a sweep to recanonize the TN
return canonize!(tn, bond)
return canonize!(tn)
end

overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)'))
Expand Down Expand Up @@ -333,7 +376,7 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth

@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

return simple_update!(form(ψ), ψ, gate; kwargs...)
return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
end

# TODO a lot of problems with merging... maybe we shouldn't merge manually
Expand Down Expand Up @@ -376,9 +419,19 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth
rinds = filter(!=(vind), inds(tensors(ψ; at=siter)))
contract!(ψ; bond)

# TODO replace for `merge!` when #243 is fixed
# reindex contracting indices to temporary names to avoid issues
oinds = Dict(site => inds(ψ; at=site) for site in sites(gate; set=:outputs))
tmpinds = Dict(site => gensym(:tmp) for site in sites(gate; set=:inputs))
replace!(gate, [inds(gate; at=site) => i for (site, i) in tmpinds])
replace!(ψ, [inds(ψ; at=site') => i for (site, i) in tmpinds])

# NOTE `replace!` is getting confused when a index is already there even if it would be overriden
# TODO fix this to be handled in one call -> replace when #244 is fixed
replace!(gate, [inds(gate; at=site) => gensym() for (site, i) in oinds])
replace!(gate, [inds(gate; at=site) => i for (site, i) in oinds])

# contract physical inds with gate
@reindex! outputs(ψ) => outputs(gate) reset = false
@reindex! inputs(gate) => outputs(ψ) reset = false
merge!(ψ, gate; reset=false)
contract!(ψ, inds(gate; set=:inputs))

Expand All @@ -395,129 +448,8 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth
end

# TODO remove `renormalize` argument?
# TODO refactor code
# TODO optimize correctly -> avoid recanonization + use lateral Λs
function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
@assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
gate = copy(gate)

bond = sitel, siter = minmax(sites(gate; set=:outputs)...)
left_inds::Vector{Symbol} = !isnothing(leftindex(ψ, sitel)) ? [leftindex(ψ, sitel)] : Symbol[]
right_inds::Vector{Symbol} = !isnothing(rightindex(ψ, siter)) ? [rightindex(ψ, siter)] : Symbol[]

virtualind::Symbol = inds(ψ; bond=bond)

contract_2sitewf!(ψ, bond)

# reindex contracting index
contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)]
replace!(
ψ,
map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index)
inds(ψ; at=site') => contracting_index
end,
)
replace!(
gate,
map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index)
inds(gate; at=site) => contracting_index
end,
)

# replace output indices of the gate for gensym indices
output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)]
replace!(
gate,
map(zip(sites(gate; set=:outputs), output_inds)) do (site, out)
inds(gate; at=site) => out
end,
)

# reindex output of gate to match TN sitemap
for site in sites(gate; set=:outputs)
if inds(ψ; at=site) != inds(gate; at=site)
replace!(gate, inds(gate; at=site) => inds(ψ; at=site))
end
end

# contract physical inds
merge!(ψ, gate)
contract!(ψ, contracting_inds)

# decompose using SVD
push!(left_inds, inds(ψ; at=sitel))
push!(right_inds, inds(ψ; at=siter))

unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind)

# truncate virtual index
if any(!isnothing, [threshold, maxdim])
truncate!(ψ, bond; threshold, maxdim)
renormalize && normalize!(tensors(ψ; between=bond))
end

return ψ
end

# TODO refactor code
"""
contract_2sitewf!(ψ::AbstractAnsatz, bond)

For a given [`AbstractAnsatz`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁,
where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ.
"""
function contract_2sitewf!(ψ::AbstractAnsatz, bond)
@assert form(ψ) == Canonical() "The tensor network must be in canonical form"

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)

contract!(ψ, inds(ψ; bond=bond))

return ψ
end

# TODO refactor code
"""
unpack_2sitewf!(ψ::AbstractAnsatz, bond)

For a given [`AbstractAnsatz`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical
form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`.
"""
function unpack_2sitewf!(ψ::AbstractAnsatz, bond, left_inds, right_inds, virtualind)
@assert form(ψ) == Canonical() "The tensor network must be in canonical form"

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

# do svd of the θ tensor
θ = tensors(ψ; at=sitel)
U, s, Vt = svd(θ; left_inds, right_inds, virtualind)

# contract with the inverse of Λᵢ and Λᵢ₊₂
Γᵢ₋₁ =
isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=())
Γᵢ =
isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=())

delete!(ψ, θ)

push!(ψ, Γᵢ₋₁)
push!(ψ, s)
push!(ψ, Γᵢ)

return ψ
simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
return canonize!(ψ)
end
2 changes: 2 additions & 0 deletions src/Lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ end
Base.copy(lattice::Lattice) = Lattice(copy(lattice.mapping), copy(lattice.graph))
Base.:(==)(a::Lattice, b::Lattice) = a.mapping == b.mapping && a.graph == b.graph

# TODO these where needed by ChainRulesTestUtils, do we still need them?
Base.zero(::Type{Lattice}) = Lattice(BijectiveIdDict{Site,Int}(), zero(Graphs.SimpleGraph{Int}))
Base.zero(::Lattice) = zero(Lattice)

Graphs.is_directed(::Type{Lattice}) = false

function Graphs.vertices(lattice::Lattice)
Expand Down
Loading
Loading