diff --git a/Project.toml b/Project.toml index c3d97358..12c8898f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.13.12" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" @@ -53,6 +54,7 @@ AbstractTrees = "0.4.4" Adapt = "4" Combinatorics = "1" Compat = "3, 4" +ConstructionBase = "1.6.0" DataGraphs = "0.2.3" DataStructures = "0.18" Dictionaries = "0.4" diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 0ca52a12..582a51e2 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -34,11 +34,7 @@ include("formnetworks/quadraticformnetwork.jl") include("gauging.jl") include("utils.jl") include("update_observer.jl") -include("solvers/local_solvers/eigsolve.jl") -include("solvers/local_solvers/exponentiate.jl") -include("solvers/local_solvers/dmrg_x.jl") -include("solvers/local_solvers/contract.jl") -include("solvers/local_solvers/linsolve.jl") + include("treetensornetworks/abstracttreetensornetwork.jl") include("treetensornetworks/treetensornetwork.jl") include("treetensornetworks/opsum_to_ttn/matelem.jl") @@ -48,18 +44,27 @@ include("treetensornetworks/projttns/abstractprojttn.jl") include("treetensornetworks/projttns/projttn.jl") include("treetensornetworks/projttns/projttnsum.jl") include("treetensornetworks/projttns/projouterprodttn.jl") -include("solvers/solver_utils.jl") -include("solvers/defaults.jl") -include("solvers/insert/insert.jl") -include("solvers/extract/extract.jl") -include("solvers/alternating_update/alternating_update.jl") -include("solvers/alternating_update/region_update.jl") -include("solvers/tdvp.jl") -include("solvers/dmrg.jl") -include("solvers/dmrg_x.jl") -include("solvers/contract.jl") -include("solvers/linsolve.jl") -include("solvers/sweep_plans/sweep_plans.jl") + +include("solvers/local_solvers/eigsolve.jl") +include("solvers/local_solvers/exponentiate.jl") +include("solvers/local_solvers/runge_kutta.jl") +include("solvers/truncation_parameters.jl") +include("solvers/iterators.jl") +include("solvers/adapters.jl") +include("solvers/sweep_solve.jl") +include("solvers/region_plans/dfs_plans.jl") +include("solvers/region_plans/euler_tour.jl") +include("solvers/region_plans/euler_plans.jl") +include("solvers/region_plans/tdvp_region_plans.jl") +include("solvers/extracter.jl") +include("solvers/inserter.jl") +include("solvers/subspace/subspace.jl") +include("solvers/subspace/densitymatrix.jl") +include("solvers/permute_indices.jl") +include("solvers/operator_map.jl") +include("solvers/eigsolve.jl") +include("solvers/applyexp.jl") + include("apply.jl") include("inner.jl") include("normalize.jl") diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl new file mode 100644 index 00000000..7c033d8e --- /dev/null +++ b/src/solvers/adapters.jl @@ -0,0 +1,32 @@ + +# +# TupleRegionIterator +# +# Adapts outputs to be (region, region_kwargs) tuples +# +# More generic design? maybe just assuming RegionIterator +# or its outputs implement some interface function that +# generates each tuple? +# + +mutable struct TupleRegionIterator{RegionIter} + region_iterator::RegionIter +end + +region_iterator(T::TupleRegionIterator) = T.region_iterator + +function Base.iterate(T::TupleRegionIterator, which=1) + state = iterate(region_iterator(T), which) + isnothing(state) && return nothing + (current_region, region_kwargs) = current_region_plan(region_iterator(T)) + return (current_region, region_kwargs), last(state) +end + +""" + region_tuples(R::RegionIterator) + +The `region_tuples` adapter converts a RegionIterator into an +iterator which outputs a tuple of the form (current_region, current_region_kwargs) +at each step. +""" +region_tuples(R::RegionIterator) = TupleRegionIterator(R) diff --git a/src/solvers/alternating_update/alternating_update.jl b/src/solvers/alternating_update/alternating_update.jl deleted file mode 100644 index 69f965f0..00000000 --- a/src/solvers/alternating_update/alternating_update.jl +++ /dev/null @@ -1,143 +0,0 @@ -using ITensors: state -using NamedGraphs.GraphsExtensions: GraphsExtensions - -function alternating_update( - operator, - init_state::AbstractTTN; - nsweeps, # define default for each solver implementation - nsites, # define default for each level of solver implementation - updater, # this specifies the update performed locally - outputlevel=default_outputlevel(), - region_printer=default_region_printer, - sweep_printer=default_sweep_printer, - (sweep_observer!)=nothing, - (region_observer!)=nothing, - root_vertex=GraphsExtensions.default_root_vertex(init_state), - extracter_kwargs=(;), - extracter=default_extracter(), - updater_kwargs=(;), - inserter_kwargs=(;), - inserter=default_inserter(), - transform_operator_kwargs=(;), - transform_operator=default_transform_operator(), - kwargs..., -) - inserter_kwargs = (; inserter_kwargs..., kwargs...) - sweep_plans = default_sweep_plans( - nsweeps, - init_state; - root_vertex, - extracter, - extracter_kwargs, - updater, - updater_kwargs, - inserter, - inserter_kwargs, - transform_operator, - transform_operator_kwargs, - nsites, - ) - return alternating_update( - operator, - init_state, - sweep_plans; - outputlevel, - sweep_observer!, - region_observer!, - sweep_printer, - region_printer, - ) -end - -function alternating_update( - projected_operator, - init_state::AbstractTTN, - sweep_plans; - outputlevel=default_outputlevel(), - checkdone=default_checkdone(), # - (sweep_observer!)=nothing, - sweep_printer=default_sweep_printer,#? - (region_observer!)=nothing, - region_printer=default_region_printer, -) - state = copy(init_state) - @assert !isnothing(sweep_plans) - for which_sweep in eachindex(sweep_plans) - sweep_plan = sweep_plans[which_sweep] - sweep_time = @elapsed begin - for which_region_update in eachindex(sweep_plan) - state, projected_operator = region_update( - projected_operator, - state; - which_sweep, - sweep_plan, - region_printer, - (region_observer!), - which_region_update, - outputlevel, - ) - end - end - update_observer!( - sweep_observer!; state, which_sweep, sweep_time, outputlevel, sweep_plans - ) - !isnothing(sweep_printer) && - sweep_printer(; state, which_sweep, sweep_time, outputlevel, sweep_plans) - checkdone(; - state, - which_sweep, - outputlevel, - sweep_plan, - sweep_plans, - sweep_observer!, - region_observer!, - ) && break - end - return state -end - -function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwargs...) - projected_operator = ProjTTN(operator) - return alternating_update(projected_operator, init_state; kwargs...) -end - -function alternating_update( - operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs... -) - projected_operator = ProjTTN(operator) - return alternating_update(projected_operator, init_state, sweep_plans; kwargs...) -end - -#ToDo: Fix docstring. -""" - tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...) - tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...) - -Use the time dependent variational principle (TDVP) algorithm -to compute `exp(t*H)*init_state` using an efficient algorithm based -on alternating optimization of the MPS tensors and local Krylov -exponentiation of H. - -This version of `tdvp` accepts a representation of H as a -Vector of MPOs, Hs = [H1,H2,H3,...] such that H is defined -as H = H1+H2+H3+... -Note that this sum of MPOs is not actually computed; rather -the set of MPOs [H1,H2,H3,..] is efficiently looped over at -each step of the algorithm when optimizing the MPS. - -Returns: -* `state::MPS` - time-evolved MPS -""" -function alternating_update( - operators::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs... -) - projected_operators = ProjTTNSum(operators) - return alternating_update(projected_operators, init_state; kwargs...) -end - -function alternating_update( - operators::Vector{<:AbstractTTN}, init_state::AbstractTTN, sweep_plans; kwargs... -) - projected_operators = ProjTTNSum(operators) - return alternating_update(projected_operators, init_state, sweep_plans; kwargs...) -end diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl deleted file mode 100644 index c741c82a..00000000 --- a/src/solvers/alternating_update/region_update.jl +++ /dev/null @@ -1,77 +0,0 @@ -function region_update( - projected_operator, - state; - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - region_printer, - (region_observer!), -) - (region, region_kwargs) = sweep_plan[which_region_update] - (; - extracter, - extracter_kwargs, - updater, - updater_kwargs, - inserter, - inserter_kwargs, - transform_operator, - transform_operator_kwargs, - internal_kwargs, - ) = region_kwargs - - # ToDo: remove orthogonality center on vertex for generality - # region carries same information - if !isnothing(transform_operator) - projected_operator = transform_operator( - state, projected_operator; outputlevel, transform_operator_kwargs... - ) - end - state, projected_operator, phi = extracter( - state, projected_operator, region; extracter_kwargs..., internal_kwargs - ) - # create references, in case solver does (out-of-place) modify PH or state - state! = Ref(state) - projected_operator! = Ref(projected_operator) - # args passed by reference are supposed to be modified out of place - phi, info = updater( - phi; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - updater_kwargs..., - internal_kwargs, - ) - state = state![] - projected_operator = projected_operator![] - # ToDo: implement noise term as updater - #drho = nothing - #ortho = "left" #i guess with respect to ordered vertices that's valid but may be cleaner to use next_region logic - #if noise > 0.0 && isforward(direction) - # drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees... - # so noiseterm is a solver - #end - #if isa(region, AbstractEdge) && - state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs) - all_kwargs = (; - which_region_update, - sweep_plan, - total_sweep_steps=length(sweep_plan), - end_of_sweep=(which_region_update == length(sweep_plan)), - state, - region, - which_sweep, - spec, - outputlevel, - info..., - region_kwargs..., - internal_kwargs..., - ) - update_observer!(region_observer!; all_kwargs...) - !(isnothing(region_printer)) && region_printer(; all_kwargs...) - return state, projected_operator -end diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl new file mode 100644 index 00000000..dfaf6e99 --- /dev/null +++ b/src/solvers/applyexp.jl @@ -0,0 +1,103 @@ +using Printf: @printf +import ConstructionBase: setproperties + +@kwdef mutable struct ApplyExpProblem{State} + state::State + operator + current_time::Number = 0.0 +end + +ITensorNetworks.state(A::ApplyExpProblem) = A.state +operator(A::ApplyExpProblem) = A.operator +current_time(A::ApplyExpProblem) = A.current_time + +function region_plan(tdvp::ApplyExpProblem; nsites, time_step, sweep_kwargs...) + return tdvp_regions(state(tdvp), time_step; nsites, sweep_kwargs...) +end + +function updater( + A::ApplyExpProblem, + local_state, + region_iterator; + nsites, + time_step, + solver=runge_kutta_solver, + outputlevel, + kws..., +) + local_state, info = solver(x->optimal_map(operator(A), x), time_step, local_state; kws...) + + if nsites==1 + curr_reg = current_region(region_iterator) + next_reg = next_region(region_iterator) + if !isnothing(next_reg) && next_reg != curr_reg + next_edge = first(edge_sequence_between_regions(state(A), curr_reg, next_reg)) + v1, v2 = src(next_edge), dst(next_edge) + psi = copy(state(A)) + psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2])) + shifted_operator = position(operator(A), psi, NamedEdge(v1=>v2)) + R_t, _ = solver(x->optimal_map(shifted_operator, x), -time_step, R; kws...) + local_state = psi[v1]*R_t + end + end + + curr_time = current_time(A) + time_step + A = setproperties(A; current_time=curr_time) + + return A, local_state +end + +function applyexp_sweep_printer( + region_iterator; outputlevel, sweep, nsweeps, process_time=identity, kws... +) + if outputlevel >= 1 + T = problem(region_iterator) + @printf(" Current time = %s, ", process_time(current_time(T))) + @printf("maxlinkdim=%d", maxlinkdim(state(T))) + println() + flush(stdout) + end +end + +function applyexp( + init_prob, + exponents; + extracter_kwargs=(;), + updater_kwargs=(;), + inserter_kwargs=(;), + outputlevel=0, + nsites=1, + tdvp_order=4, + sweep_printer=applyexp_sweep_printer, + kws..., +) + time_steps = diff([zero(eltype(exponents)); exponents])[2:end] + sweep_kws = (; + outputlevel, extracter_kwargs, inserter_kwargs, nsites, tdvp_order, updater_kwargs + ) + kws_array = [(; sweep_kws..., time_step=t) for t in time_steps] + sweep_iter = sweep_iterator(init_prob, kws_array) + converged_prob = sweep_solve(sweep_iter; outputlevel, sweep_printer, kws...) + return state(converged_prob) +end + +function applyexp(H, init_state, exponents; kws...) + init_prob = ApplyExpProblem(; + state=permute_indices(init_state), operator=ProjTTN(permute_indices(H)) + ) + return applyexp(init_prob, exponents; kws...) +end + +process_real_times(z) = round(-imag(z); digits=10) + +function time_evolve( + H, + init_state, + time_points; + process_time=process_real_times, + sweep_printer=(a...; k...)->applyexp_sweep_printer(a...; process_time, k...), + kws..., +) + exponents = [-im*t for t in time_points] + return applyexp(H, init_state, exponents; sweep_printer, kws...) +end diff --git a/src/solvers/defaults.jl b/src/solvers/defaults.jl deleted file mode 100644 index 5f25e087..00000000 --- a/src/solvers/defaults.jl +++ /dev/null @@ -1,66 +0,0 @@ -using Printf: @printf, @sprintf -default_outputlevel() = 0 -default_nsites() = 2 -default_nsweeps() = 1 #? or nothing? -default_extracter() = default_extracter -default_inserter() = default_inserter -default_checkdone() = (; kws...) -> false -default_transform_operator() = nothing - -format(x) = @sprintf("%s", x) -format(x::AbstractFloat) = @sprintf("%.1E", x) - -function default_region_printer(; - inserter_kwargs, - outputlevel, - state, - sweep_plan, - spec, - which_region_update, - which_sweep, - kwargs..., -) - if outputlevel >= 2 - region = first(sweep_plan[which_region_update]) - @printf("Sweep %d, region=%s \n", which_sweep, region) - print(" Truncated using") - for key in [:cutoff, :maxdim, :mindim] - if haskey(inserter_kwargs, key) - print(" ", key, "=", format(inserter_kwargs[key])) - end - end - println() - if spec != nothing - @printf( - " Trunc. err=%.2E, bond dimension %d\n", - spec.truncerr, - linkdim(state, edgetype(state)(region...)) - ) - end - flush(stdout) - end -end - -#ToDo: Implement sweep_time_printer more generally -#ToDo: Implement more printers -#ToDo: Move to another file? -function default_sweep_time_printer(; outputlevel, which_sweep, kwargs...) - if outputlevel >= 1 - sweeps_per_step = order ÷ 2 - if which_sweep % sweeps_per_step == 0 - current_time = (which_sweep / sweeps_per_step) * time_step - println("Current time (sweep $which_sweep) = ", round(current_time; digits=3)) - end - end - return nothing -end - -function default_sweep_printer(; outputlevel, state, which_sweep, sweep_time, kwargs...) - if outputlevel >= 1 - print("After sweep ", which_sweep, ":") - print(" maxlinkdim=", maxlinkdim(state)) - print(" cpu_time=", round(sweep_time; digits=3)) - println() - flush(stdout) - end -end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl new file mode 100644 index 00000000..0edc26dd --- /dev/null +++ b/src/solvers/eigsolve.jl @@ -0,0 +1,76 @@ +using Printf: @printf +import ConstructionBase: setproperties + +@kwdef mutable struct EigsolveProblem{State,Operator} + state::State + operator::Operator + eigenvalue::Number = Inf +end + +eigenvalue(E::EigsolveProblem) = E.eigenvalue +ITensorNetworks.state(E::EigsolveProblem) = E.state +operator(E::EigsolveProblem) = E.operator + +function updater( + E::EigsolveProblem, + local_state, + region_iterator; + outputlevel, + solver=eigsolve_solver, + kws..., +) + eigval, local_state = solver(ψ->optimal_map(operator(E), ψ), local_state; kws...) + E = setproperties(E; eigenvalue=eigval) + if outputlevel >= 2 + @printf(" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(E)) + end + return E, local_state +end + +function eigsolve_sweep_printer(region_iterator; outputlevel, sweep, nsweeps, kws...) + if outputlevel >= 1 + if nsweeps >= 10 + @printf("After sweep %02d/%d ", sweep, nsweeps) + else + @printf("After sweep %d/%d ", sweep, nsweeps) + end + E = problem(region_iterator) + @printf("eigenvalue=%.12f ", eigenvalue(E)) + @printf("maxlinkdim=%d", maxlinkdim(state(E))) + println() + flush(stdout) + end +end + +function eigsolve( + init_prob; + nsweeps, + nsites=1, + outputlevel=0, + extracter_kwargs=(;), + updater_kwargs=(;), + inserter_kwargs=(;), + sweep_printer=eigsolve_sweep_printer, + kws..., +) + sweep_iter = sweep_iterator( + init_prob, + nsweeps; + nsites, + outputlevel, + extracter_kwargs, + updater_kwargs, + inserter_kwargs, + ) + prob = sweep_solve(sweep_iter; outputlevel, sweep_printer, kws...) + return eigenvalue(prob), state(prob) +end + +function eigsolve(H, init_state; kws...) + init_prob = EigsolveProblem(; + state=permute_indices(init_state), operator=ProjTTN(permute_indices(H)) + ) + return eigsolve(init_prob; kws...) +end + +dmrg(args...; kws...) = eigsolve(args...; kws...) diff --git a/src/solvers/extract/extract.jl b/src/solvers/extract/extract.jl deleted file mode 100644 index 1013d1bd..00000000 --- a/src/solvers/extract/extract.jl +++ /dev/null @@ -1,28 +0,0 @@ -# Here extract_local_tensor and insert_local_tensor -# are essentially inverse operations, adapted for different kinds of -# algorithms and networks. -# -# In the simplest case, exact_local_tensor contracts together a few -# tensors of the network and returns the result, while -# insert_local_tensors takes that tensor and factorizes it back -# apart and puts it back into the network. -# - -function default_extracter(state, projected_operator, region; internal_kwargs) - if isa(region, AbstractEdge) - # TODO: add functionality for orthogonalizing onto a bond so that can be called instead - vsrc, vdst = src(region), dst(region) - state = orthogonalize(state, vsrc) - left_inds = uniqueinds(state[vsrc], state[vdst]) - U, S, V = svd( - state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region) - ) - state[vsrc] = U - local_tensor = S * V - else - state = orthogonalize(state, region) - local_tensor = prod(state[v] for v in region) - end - projected_operator = position(projected_operator, state, region) - return state, projected_operator, local_tensor -end diff --git a/src/solvers/extracter.jl b/src/solvers/extracter.jl new file mode 100644 index 00000000..1bf9f127 --- /dev/null +++ b/src/solvers/extracter.jl @@ -0,0 +1,17 @@ +import ConstructionBase: setproperties + +function extracter(problem, region_iterator; sweep, trunc=(;), kws...) + trunc = truncation_parameters(sweep; trunc...) + region = current_region(region_iterator) + psi = orthogonalize(state(problem), region) + local_state = prod(psi[v] for v in region) + problem = setproperties(problem; state=psi) + + problem, local_state = subspace_expand( + problem, local_state, region_iterator; sweep, trunc, kws... + ) + + shifted_operator = position(operator(problem), state(problem), region) + + return setproperties(problem; operator=shifted_operator), local_state +end diff --git a/src/solvers/insert/insert.jl b/src/solvers/insert/insert.jl deleted file mode 100644 index 01fb35bd..00000000 --- a/src/solvers/insert/insert.jl +++ /dev/null @@ -1,46 +0,0 @@ -# Here extract_local_tensor and insert_local_tensor -# are essentially inverse operations, adapted for different kinds of -# algorithms and networks. - -# TODO: use dense TTN constructor to make this more general. -function default_inserter( - state::AbstractTTN, - phi::ITensor, - region; - normalize=false, - maxdim=nothing, - mindim=nothing, - cutoff=nothing, - internal_kwargs, -) - state = copy(state) - spec = nothing - if length(region) == 2 - v = last(region) - e = edgetype(state)(first(region), last(region)) - indsTe = inds(state[first(region)]) - L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff) - state[first(region)] = L - else - v = only(region) - end - state[v] = phi - state = set_ortho_region(state, [v]) - normalize && (state[v] /= norm(state[v])) - return state, spec -end - -function default_inserter( - state::AbstractTTN, - phi::ITensor, - region::NamedEdge; - cutoff=nothing, - maxdim=nothing, - mindim=nothing, - normalize=false, - internal_kwargs, -) - state[dst(region)] *= phi - state = set_ortho_region(state, [dst(region)]) - return state, nothing -end diff --git a/src/solvers/inserter.jl b/src/solvers/inserter.jl new file mode 100644 index 00000000..2180fb14 --- /dev/null +++ b/src/solvers/inserter.jl @@ -0,0 +1,34 @@ +import ConstructionBase: setproperties +import NamedGraphs: edgetype + +function inserter( + problem, + local_tensor, + region_iterator; + normalize=false, + set_orthogonal_region=true, + sweep, + trunc=(;), + kws..., +) + trunc = truncation_parameters(sweep; trunc...) + + region = current_region(region_iterator) + psi = copy(state(problem)) + if length(region) == 1 + C = local_tensor + elseif length(region) == 2 + e = edgetype(psi)(first(region), last(region)) + indsTe = inds(psi[first(region)]) + tags = ITensors.tags(psi, e) + U, C, _ = factorize(local_tensor, indsTe; tags, trunc...) + @preserve_graph psi[first(region)] = U + else + error("Region of length $(length(region)) not currently supported") + end + v = last(region) + @preserve_graph psi[v] = C + psi = set_orthogonal_region ? set_ortho_region(psi, [v]) : psi + normalize && @preserve_graph psi[v] = psi[v] / norm(psi[v]) + return setproperties(problem; state=psi) +end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl new file mode 100644 index 00000000..aaf530cf --- /dev/null +++ b/src/solvers/iterators.jl @@ -0,0 +1,104 @@ +# +# SweepIterator +# + +mutable struct SweepIterator + sweep_kws + region_iter + which_sweep::Int +end + +problem(S::SweepIterator) = problem(S.region_iter) + +Base.length(S::SweepIterator) = length(S.sweep_kws) + +function Base.iterate(S::SweepIterator, which=nothing) + if isnothing(which) + sweep_kws_state = iterate(S.sweep_kws) + else + sweep_kws_state = iterate(S.sweep_kws, which) + end + isnothing(sweep_kws_state) && return nothing + current_sweep_kws, next = sweep_kws_state + + if !isnothing(which) + S.region_iter = region_iterator( + problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws... + ) + end + S.which_sweep += 1 + return S.region_iter, next +end + +function sweep_iterator(problem, sweep_kws) + region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...) + return SweepIterator(sweep_kws, region_iter, 1) +end + +function sweep_iterator(problem, nsweeps::Integer; sweep_kws...) + return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps)) +end + +# +# RegionIterator +# + +@kwdef mutable struct RegionIterator{Problem,RegionPlan} + problem::Problem + region_plan::RegionPlan + which_region::Int = 1 +end + +problem(R::RegionIterator) = R.problem +current_region_plan(R::RegionIterator) = R.region_plan[R.which_region] +current_region(R::RegionIterator) = current_region_plan(R)[1] +region_kwargs(R::RegionIterator) = current_region_plan(R)[2] +function previous_region(R::RegionIterator) + R.which_region==1 ? nothing : R.region_plan[R.which_region - 1][1] +end +function next_region(R::RegionIterator) + R.which_region==length(R.region_plan) ? nothing : R.region_plan[R.which_region + 1][1] +end +is_last_region(R::RegionIterator) = isnothing(next_region(R)) + +function Base.iterate(R::RegionIterator, which=1) + R.which_region = which + region_plan_state = iterate(R.region_plan, which) + isnothing(region_plan_state) && return nothing + (current_region, region_kwargs), next = region_plan_state + R.problem = region_iterator_action(problem(R), R; region_kwargs...) + return R, next +end + +# +# Functions associated with RegionIterator +# + +function region_iterator(problem; sweep_kwargs...) + return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...)) +end + +function region_iterator_action( + problem, + region_iterator; + extracter_kwargs=(;), + updater_kwargs=(;), + inserter_kwargs=(;), + sweep, + kws..., +) + problem, local_state = extracter( + problem, region_iterator; extracter_kwargs..., sweep, kws... + ) + problem, local_state = updater( + problem, local_state, region_iterator; updater_kwargs..., kws... + ) + problem = inserter( + problem, local_state, region_iterator; sweep, inserter_kwargs..., kws... + ) + return problem +end + +function region_plan(problem; kws...) + return euler_sweep(state(problem); kws...) +end diff --git a/src/solvers/local_solvers/contract.jl b/src/solvers/local_solvers/contract.jl deleted file mode 100644 index bffefdef..00000000 --- a/src/solvers/local_solvers/contract.jl +++ /dev/null @@ -1,13 +0,0 @@ -function contract_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, -) - P = projected_operator![] - return contract_ket(P, ITensor(one(Bool))), (;) -end diff --git a/src/solvers/local_solvers/dmrg_x.jl b/src/solvers/local_solvers/dmrg_x.jl deleted file mode 100644 index 3c3ae429..00000000 --- a/src/solvers/local_solvers/dmrg_x.jl +++ /dev/null @@ -1,22 +0,0 @@ -using ITensors: ITensor, contract, dag, onehot, uniqueind -using ITensors.NDTensors: array -using LinearAlgebra: eigen - -function dmrg_x_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, -) - H = contract(projected_operator![], ITensor(true)) - D, U = eigen(H; ishermitian=true) - u = uniqueind(U, H) - max_overlap, max_ind = findmax(abs, array(dag(init) * U)) - U_max = U * dag(onehot(u => max_ind)) - eigvals = [((onehot(u => max_ind)' * D) * dag(onehot(u => max_ind)))[]] - return U_max, (; eigvals) -end diff --git a/src/solvers/local_solvers/eigsolve.jl b/src/solvers/local_solvers/eigsolve.jl index ed993d80..96f1c9fb 100644 --- a/src/solvers/local_solvers/eigsolve.jl +++ b/src/solvers/local_solvers/eigsolve.jl @@ -1,25 +1,20 @@ -using KrylovKit: eigsolve +using KrylovKit: KrylovKit -function eigsolve_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, +function eigsolve_solver( + operator, + init, + howmany=1; which_eigval=:SR, ishermitian=true, - tol=1e-14, + tol=1E-14, krylovdim=3, maxiter=1, verbosity=0, eager=false, + kws..., ) - howmany = 1 - vals, vecs, info = eigsolve( - projected_operator![], + vals, vecs, info = KrylovKit.eigsolve( + operator, init, howmany, which_eigval; @@ -30,5 +25,5 @@ function eigsolve_updater( verbosity, eager, ) - return vecs[1], (; info, eigvals=vals) + return vals[1], vecs[1] end diff --git a/src/solvers/local_solvers/exponentiate.jl b/src/solvers/local_solvers/exponentiate.jl index c70a91c5..b0929916 100644 --- a/src/solvers/local_solvers/exponentiate.jl +++ b/src/solvers/local_solvers/exponentiate.jl @@ -1,14 +1,9 @@ -using KrylovKit: exponentiate +using KrylovKit: KrylovKit -function exponentiate_updater( +function exponentiate_solver( + operator, + time, init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, krylovdim=30, maxiter=100, verbosity=0, @@ -16,11 +11,11 @@ function exponentiate_updater( ishermitian=true, issymmetric=true, eager=true, + kws..., ) - (; time_step) = internal_kwargs - result, exp_info = exponentiate( - projected_operator![], - time_step, + result, exp_info = KrylovKit.exponentiate( + operator, + time, init; eager, krylovdim, @@ -30,5 +25,5 @@ function exponentiate_updater( ishermitian, issymmetric, ) - return result, (; info=exp_info) + return result, exp_info end diff --git a/src/solvers/local_solvers/linsolve.jl b/src/solvers/local_solvers/linsolve.jl deleted file mode 100644 index c5b8c4c6..00000000 --- a/src/solvers/local_solvers/linsolve.jl +++ /dev/null @@ -1,26 +0,0 @@ -using KrylovKit: linsolve - -function linsolve_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - region_kwargs, - ishermitian=false, - tol=1E-14, - krylovdim=30, - maxiter=100, - verbosity=0, - a₀, - a₁, -) - P = projected_operator![] - b = dag(only(proj_mps(P))) - x, info = linsolve( - P, b, init, a₀, a₁; ishermitian=false, tol, krylovdim, maxiter, verbosity - ) - return x, (;) -end diff --git a/src/solvers/local_solvers/runge_kutta.jl b/src/solvers/local_solvers/runge_kutta.jl new file mode 100644 index 00000000..03294f0d --- /dev/null +++ b/src/solvers/local_solvers/runge_kutta.jl @@ -0,0 +1,25 @@ + +function runge_kutta_2(H, t, ψ0) + Hψ = H(ψ0) + H2ψ = H(Hψ) + return (ψ0 + t * Hψ + (t^2 / 2) * H2ψ) +end + +function runge_kutta_4(H, t, ψ0) + k1 = H(ψ0) + k2 = k1 + (t / 2) * H(k1) + k3 = k1 + (t / 2) * H(k2) + k4 = k1 + t * H(k3) + return ψ0 + (t / 6) * (k1 + 2 * k2 + 2 * k3 + k4) +end + +function runge_kutta_solver(H, time, ψ; order=4, kws...) + if order == 4 + Hψ = runge_kutta_4(H, time, ψ) + elseif order == 2 + Hψ = runge_kutta_2(H, time, ψ) + else + error("For runge_kutta_solver, must specify `order` keyword") + end + return Hψ, (;) +end diff --git a/src/solvers/operator_map.jl b/src/solvers/operator_map.jl new file mode 100644 index 00000000..78a6f264 --- /dev/null +++ b/src/solvers/operator_map.jl @@ -0,0 +1,9 @@ + +function optimal_map(P::ProjTTN, ψ) + envs = [environment(P, e) for e in incident_edges(P)] + site_ops = [operator(P)[s] for s in sites(P)] + contract_list = [envs..., site_ops..., ψ] + sequence = contraction_sequence(contract_list; alg="optimal") + Pψ = contract(contract_list; sequence) + return noprime(Pψ) +end diff --git a/src/solvers/permute_indices.jl b/src/solvers/permute_indices.jl new file mode 100644 index 00000000..90526965 --- /dev/null +++ b/src/solvers/permute_indices.jl @@ -0,0 +1,16 @@ + +function permute_indices(tn) + si = siteinds(tn) + ptn = copy(tn) + for v in vertices(tn) + is = inds(tn[v]) + ls = setdiff(is, si[v]) + isempty(ls) && continue + new_is = [first(ls), si[v]...] + if length(ls) >= 2 + new_is = vcat(new_is, ls[2:end]) + end + ptn[v] = permute(tn[v], new_is) + end + return ptn +end diff --git a/src/solvers/contract.jl b/src/solvers/previous_interfaces/contract.jl similarity index 100% rename from src/solvers/contract.jl rename to src/solvers/previous_interfaces/contract.jl diff --git a/src/solvers/dmrg.jl b/src/solvers/previous_interfaces/dmrg.jl similarity index 100% rename from src/solvers/dmrg.jl rename to src/solvers/previous_interfaces/dmrg.jl diff --git a/src/solvers/dmrg_x.jl b/src/solvers/previous_interfaces/dmrg_x.jl similarity index 100% rename from src/solvers/dmrg_x.jl rename to src/solvers/previous_interfaces/dmrg_x.jl diff --git a/src/solvers/linsolve.jl b/src/solvers/previous_interfaces/linsolve.jl similarity index 100% rename from src/solvers/linsolve.jl rename to src/solvers/previous_interfaces/linsolve.jl diff --git a/src/solvers/tdvp.jl b/src/solvers/previous_interfaces/tdvp.jl similarity index 100% rename from src/solvers/tdvp.jl rename to src/solvers/previous_interfaces/tdvp.jl diff --git a/src/solvers/region_plans.jl b/src/solvers/region_plans.jl new file mode 100644 index 00000000..be858dfd --- /dev/null +++ b/src/solvers/region_plans.jl @@ -0,0 +1,73 @@ +import Graphs: AbstractGraph, AbstractEdge, edges, dst, src, vertices +import NamedGraphs: GraphsExtensions + +#function basic_path_regions(g::AbstractGraph; sweep_kwargs...) +# fwd_sweep = [([src(e), dst(e)], sweep_kwargs) for e in edges(g)] +# return [fwd_sweep..., reverse(fwd_sweep)...] +#end + +function tdvp_regions( + g::AbstractGraph, time_step; nsites=1, updater_kwargs, sweep_kwargs... +) + @assert nsites==1 + fwd_up_args = (; time=(time_step / 2), updater_kwargs...) + rev_up_args = (; time=(-time_step / 2), updater_kwargs...) + + fwd_sweep = [] + for e in edges(g) + push!(fwd_sweep, ([src(e)], (; updater_kwargs=fwd_up_args, sweep_kwargs...))) + push!(fwd_sweep, (e, (; updater_kwargs=rev_up_args, sweep_kwargs...))) + end + push!(fwd_sweep, ([dst(last(edges(g)))], (; updater_kwargs=fwd_up_args, sweep_kwargs...))) + + # Reverse regions as well as ordering of regions: + rev_sweep = [(reverse(rk[1]), rk[2]) for rk in reverse(fwd_sweep)] + + return [fwd_sweep..., rev_sweep...] +end + +function overlap(ea::AbstractEdge, eb::AbstractEdge) + return intersect([src(ea), dst(ea)], [src(eb), dst(eb)]) +end + +function forward_region(edges, which_edge; nsites=1, region_kwargs=(;)) + current_edge = edges[which_edge] + if nsites == 1 + #handle edge case + if current_edge == last(edges) + overlapping_vertex = only( + union([overlap(e, current_edge) for e in edges[1:(which_edge - 1)]]...) + ) + nonoverlapping_vertex = only( + setdiff([src(current_edge), dst(current_edge)], [overlapping_vertex]) + ) + return [ + ([overlapping_vertex], region_kwargs), ([nonoverlapping_vertex], region_kwargs) + ] + else + future_edges = edges[(which_edge + 1):end] + future_edges = isa(future_edges, AbstractEdge) ? [future_edges] : future_edges + overlapping_vertex = only(union([overlap(e, current_edge) for e in future_edges]...)) + nonoverlapping_vertex = only( + setdiff([src(current_edge), dst(current_edge)], [overlapping_vertex]) + ) + return [([nonoverlapping_vertex], region_kwargs)] + end + elseif nsites == 2 + return [([src(current_edge), dst(current_edge)], region_kwargs)] + end +end + +function basic_region_plan( + graph::AbstractGraph; + nsites, + root_vertex=GraphsExtensions.default_root_vertex(graph), + sweep_kwargs..., +) + edges = GraphsExtensions.post_order_dfs_edges(graph, root_vertex) + fwd_sweep = [ + forward_region(edges, i; nsites, region_kwargs=sweep_kwargs) for i in 1:length(edges) + ] + fwd_sweep = collect(Iterators.flatten(fwd_sweep)) + return [fwd_sweep..., reverse(fwd_sweep)...] +end diff --git a/src/solvers/region_plans/dfs_plans.jl b/src/solvers/region_plans/dfs_plans.jl new file mode 100644 index 00000000..facbaca6 --- /dev/null +++ b/src/solvers/region_plans/dfs_plans.jl @@ -0,0 +1,22 @@ +import Graphs: dst, src +import NamedGraphs.GraphsExtensions: + default_root_vertex, post_order_dfs_edges, post_order_dfs_vertices + +function post_order_dfs_plan( + graph; nsites, root_vertex=default_root_vertex(graph), sweep_kwargs... +) + if nsites == 1 + vertices = post_order_dfs_vertices(graph, root_vertex) + fwd_sweep = [([v], sweep_kwargs) for v in vertices] + elseif nsites == 2 + edges = post_order_dfs_edges(graph, root_vertex) + fwd_sweep = [([src(e), dst(e)], sweep_kwargs) for e in edges] + end + return fwd_sweep +end + +function post_order_dfs_sweep(args...; kws...) + fwd_sweep = post_order_dfs_plan(args...; kws...) + rev_sweep = [(reverse(reg_kws[1]), reg_kws[2]) for reg_kws in reverse(fwd_sweep)] + return [fwd_sweep..., rev_sweep...] +end diff --git a/src/solvers/region_plans/euler_plans.jl b/src/solvers/region_plans/euler_plans.jl new file mode 100644 index 00000000..f3b9a96e --- /dev/null +++ b/src/solvers/region_plans/euler_plans.jl @@ -0,0 +1,13 @@ +import Graphs: dst, src +import NamedGraphs.GraphsExtensions: default_root_vertex + +function euler_sweep(graph; nsites, root_vertex=default_root_vertex(graph), sweep_kwargs...) + if nsites == 1 + vertices = euler_tour_vertices(graph, root_vertex) + sweep = [([v], sweep_kwargs) for v in vertices] + elseif nsites == 2 + edges = euler_tour_edges(graph, root_vertex) + sweep = [([src(e), dst(e)], sweep_kwargs) for e in edges] + end + return sweep +end diff --git a/src/solvers/region_plans/euler_tour.jl b/src/solvers/region_plans/euler_tour.jl new file mode 100644 index 00000000..6aeb0029 --- /dev/null +++ b/src/solvers/region_plans/euler_tour.jl @@ -0,0 +1,48 @@ +import Graphs: dst, edges, src, vertices +import NamedGraphs as ng + +function compute_adjacencies(G) + adj = Dict(v => Vector{ng.vertextype(G)}() for v in vertices(G)) + for e in edges(G) + push!(adj[src(e)], dst(e)) + push!(adj[dst(e)], src(e)) + end + return adj +end + +function euler_tour_edges(G, start_vertex) + adj = compute_adjacencies(G) + etype = ng.edgetype(G) + vtype = ng.vertextype(G) + visited = Set{Tuple{vtype,vtype}}() + tour = Vector{etype}() + stack = [start_vertex] + while !isempty(stack) + u = stack[end] + pushed = false + for v in adj[u] + if (u, v) ∉ visited + push!(visited, (u, v)) + push!(visited, (v, u)) + push!(tour, etype(u => v)) + push!(stack, v) + pushed = true + break # handle one neighbor at a time + end + end + if !pushed + pop!(stack) + if !isempty(stack) + v = stack[end] + push!(tour, etype(u => v)) # Backtracking step + end + end + end + return tour +end + +function euler_tour_vertices(G, start_vertex) + edges = euler_tour_edges(G, start_vertex) + isempty(edges) && return Vector{eltype(vertices(G))}[] + return [src(edges[1]), dst.(edges)...] +end diff --git a/src/solvers/region_plans/tdvp_region_plans.jl b/src/solvers/region_plans/tdvp_region_plans.jl new file mode 100644 index 00000000..0624d666 --- /dev/null +++ b/src/solvers/region_plans/tdvp_region_plans.jl @@ -0,0 +1,44 @@ +function tdvp_sub_time_steps(tdvp_order) + if tdvp_order == 1 + return [1.0] + elseif tdvp_order == 2 + return [1 / 2, 1 / 2] + elseif tdvp_order == 4 + s = (2 - 2^(1 / 3))^(-1) + return [s/2, s/2, 1/2 - s, 1/2 - s, s/2, s/2] + else + error("TDVP order of $tdvp_order not supported") + end +end + +function first_order_sweep( + graph, time_step, dir=Base.Forward; updater_kwargs, nsites, kws... +) + basic_fwd_sweep = post_order_dfs_plan(graph; nsites, kws...) + updater_kwargs = (; nsites, time_step, updater_kwargs...) + sweep = [] + for (j, (region, region_kws)) in enumerate(basic_fwd_sweep) + push!(sweep, (region, (; nsites, updater_kwargs, region_kws...))) + if length(region) == 2 && j < length(basic_fwd_sweep) + rev_kwargs = (; updater_kwargs..., time_step=(-updater_kwargs.time_step)) + push!(sweep, ([last(region)], (; updater_kwargs=rev_kwargs, region_kws...))) + end + end + if dir==Base.Reverse + # Reverse regions as well as ordering of regions + sweep = [(reverse(reg_kws[1]), reg_kws[2]) for reg_kws in reverse(sweep)] + end + return sweep +end + +function tdvp_regions(graph, time_step; updater_kwargs, tdvp_order, nsites, kws...) + sweep_plan = [] + for (step, weight) in enumerate(tdvp_sub_time_steps(tdvp_order)) + dir = isodd(step) ? Base.Forward : Base.Reverse + append!( + sweep_plan, + first_order_sweep(graph, weight*time_step, dir; updater_kwargs, nsites, kws...), + ) + end + return sweep_plan +end diff --git a/src/solvers/subspace.jl b/src/solvers/subspace.jl new file mode 100644 index 00000000..7fccc8f0 --- /dev/null +++ b/src/solvers/subspace.jl @@ -0,0 +1,64 @@ +using ITensors: + commonind, + dag, + dim, + directsum, + dot, + hascommoninds, + Index, + norm, + onehot, + uniqueinds, + random_itensor + +# TODO: hoist num_expand default value out to a function or similar +function subspace_expand!( + problem::EigsolveProblem, local_tensor, region; prev_region, num_expand=4, kws... +) + if isnothing(prev_region) || isa(region, AbstractEdge) + return local_tensor + end + + prev_vertex_set = setdiff(prev_region, region) + (length(prev_vertex_set) != 1) && return local_tensor + prev_vertex = only(prev_vertex_set) + + psi = state(problem) + A = psi[prev_vertex] + + next_vertex = only(filter(v -> (it.hascommoninds(psi[v], A)), region)) + C = psi[next_vertex] + + # Analyze indices of A + # TODO: if "a" is missing, could supply a 1-dim index and put on both A and C? + a = commonind(A, C) + isnothing(a) && return local_tensor + basis_inds = uniqueinds(A, C) + + # Determine maximum value of num_expand + dim_basis = prod(dim.(basis_inds)) + num_expand = min(num_expand, dim_basis - dim(a)) + (num_expand <= 0) && return local_tensor + + # Build new subspace + function linear_map(w) + return w = w - A * (dag(A) * w) + end + random_vector() = random_itensor(basis_inds...) + Q = range_finder(linear_map, random_vector; max_rank=num_expand, oversample=0) + + # Direct sum new space with A to make Ax + qinds = [Index(1, "q$j") for j in 1:num_expand] + Q = [Q[j] * onehot(qinds[j] => 1) => qinds[j] for j in 1:num_expand] + Ax, sa = directsum(A => a, Q...) + + expander = dag(Ax) * A + psi[prev_vertex] = Ax + psi[next_vertex] = expander * C + + # TODO: avoid computing local tensor twice + # while also handling AbstractEdge region case + local_tensor = prod(psi[v] for v in region) + + return local_tensor +end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl new file mode 100644 index 00000000..7511901d --- /dev/null +++ b/src/solvers/subspace/densitymatrix.jl @@ -0,0 +1,73 @@ +using NamedGraphs.GraphsExtensions: incident_edges +using Printf: @printf + +function subspace_expand( + ::Backend"densitymatrix", + problem, + local_state::ITensor, + region_iterator; + north_pass=1, + expansion_factor=default_expansion_factor(), + max_expand=default_max_expand(), + trunc, + kws..., +) + region = current_region(region_iterator) + psi = copy(state(problem)) + + prev_vertex_set = setdiff(pos(operator(problem)), region) + (length(prev_vertex_set) != 1) && return problem, local_state + prev_vertex = only(prev_vertex_set) + A = psi[prev_vertex] + + next_vertices = filter(v -> (hascommoninds(psi[v], A)), region) + isempty(next_vertices) && return problem, local_state + next_vertex = only(next_vertices) + C = psi[next_vertex] + + a = commonind(A, C) + isnothing(a) && return problem, local_state + basis_size = prod(dim.(uniqueinds(A, C))) + + expanded_maxdim = compute_expansion( + dim(a), basis_size; expansion_factor, max_expand, trunc.maxdim + ) + expanded_maxdim <= 0 && return problem, local_state + trunc = (; trunc..., maxdim=expanded_maxdim) + + envs = environments(operator(problem)) + H = operator(operator(problem)) + sqrt_rho = A + for e in incident_edges(operator(problem)) + (src(e) ∈ region || dst(e) ∈ region) && continue + sqrt_rho *= envs[e] + end + sqrt_rho *= H[prev_vertex] + + conj_proj_A(T) = (T - prime(A)*(dag(prime(A))*T)) + for pass in 1:north_pass + sqrt_rho = conj_proj_A(sqrt_rho) + end + rho = sqrt_rho * dag(noprime(sqrt_rho)) + D, U = eigen(rho; trunc..., ishermitian=true) + + Uproj(T) = (T - prime(A, a)*(dag(prime(A, a))*T)) + for pass in 1:north_pass + U = Uproj(U) + end + if norm(dag(U)*A) > 1E-10 + @printf("Warning: |U*A| = %.3E in subspace expansion\n", norm(dag(U)*A)) + return problem, local_state + end + + Ax, ax = directsum(A=>a, U=>commonind(U, D)) + #println("Old space: ", space(a)) + #println("New space: ", space(ax)) + #ITensors.pause() + expander = dag(Ax) * A + psi[prev_vertex] = Ax + psi[next_vertex] = expander * C + local_state = expander*local_state + + return setproperties(problem; state=psi), local_state +end diff --git a/src/solvers/subspace/ortho_subspace.jl b/src/solvers/subspace/ortho_subspace.jl new file mode 100644 index 00000000..7d7ca6c2 --- /dev/null +++ b/src/solvers/subspace/ortho_subspace.jl @@ -0,0 +1,77 @@ +using ITensors +using Graphs: AbstractEdge + +expand_space(χ::Integer, expansion_factor) = max(χ + 1, floor(Int, expansion_factor * χ)) + +function expand_space(χs::Vector{<:Pair}, expansion_factor) + return [q => expand_space(d, expansion_factor) for (q, d) in χs] +end + +# +# Alternative idea for "ortho" method: +# - Just make a random tensor with `basis_inds` on both sides +# (Kind of like a random density matrix.) +# - Symmetrize to make it Hermitian PSD. +# - Then do eigenvalue decomposition to get U at desired size. +# - Finally, project out space of A from U. +# + +function subspace_expand!( + ::Backend"ortho", + problem::EigsolveProblem, + local_tensor, + region_iterator; + cutoff=default_cutoff(), + maxdim=default_maxdim(), + mindim=default_mindim(), + expansion_factor=default_expansion_factor(), + max_expand=default_max_expand(), + kws..., +) + prev_region = previous_region(region_iterator) + region = current_region(region_iterator) + if isnothing(prev_region) || isa(region, AbstractEdge) + return local_tensor + end + + prev_vertex_set = setdiff(prev_region, region) + (length(prev_vertex_set) != 1) && return local_tensor + prev_vertex = only(prev_vertex_set) + + psi = state(problem) + A = psi[prev_vertex] + + next_vertices = filter(v -> (it.hascommoninds(psi[v], A)), region) + isempty(next_vertices) && return local_tensor + next_vertex = only(next_vertices) + C = psi[next_vertex] + + # Analyze indices of A + # TODO: if "a" is missing, could supply a 1-dim index and put on both A and C? + a = commonind(A, C) + isnothing(a) && return local_tensor + basis_inds = uniqueinds(A, C) + basis_size = prod(dim.(basis_inds)) + + ci = combinedind(combiner(basis_inds...)) + ax_space = expand_space(space(ci), expansion_factor) + ax = Index(ax_space, "ax") + + linear_map(w) = (w - A * (dag(A) * w)) + Y = linear_map(random_itensor(basis_inds, dag(ax))) + expand_maxdim = compute_expansion( + dim(a), basis_size; expansion_factor, max_expand, maxdim + ) + (norm(Y) <= 1E-15 || expand_maxdim <= 0) && return local_tensor + Ux, S, V = svd(Y, basis_inds; cutoff=1E-14, maxdim=expand_maxdim, lefttags="ux,Link") + + Ux = linear_map(Ux) + ux = commonind(Ux, S) + Ax, sa = directsum(A => a, Ux => ux) + expander = dag(Ax) * A + psi[prev_vertex] = Ax + psi[next_vertex] = expander * C + local_tensor = expander*local_tensor + + return local_tensor +end diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl new file mode 100644 index 00000000..7a332c3b --- /dev/null +++ b/src/solvers/subspace/subspace.jl @@ -0,0 +1,48 @@ +using NDTensors: NDTensors +using NDTensors.BackendSelection: Backend, @Backend_str +import ConstructionBase: setproperties + +default_expansion_factor() = 1.5 +default_max_expand() = typemax(Int) + +function subspace_expand( + problem, local_state, region_iterator; subspace_algorithm=nothing, sweep, trunc, kws... +) + return subspace_expand( + Backend(subspace_algorithm), problem, local_state, region_iterator; trunc, kws... + ) +end + +function subspace_expand(backend, problem, local_state, region_iterator; kws...) + error( + "Subspace expansion (subspace_expand!) not defined for requested combination of subspace_algorithm and problem types", + ) +end + +function subspace_expand( + backend::Backend{:nothing}, problem, local_state, region_iterator; kws... +) + problem, local_state +end + +function compute_expansion( + current_dim, + basis_size; + expansion_factor=default_expansion_factor(), + max_expand=default_max_expand(), + maxdim=default_maxdim(), +) + # Note: expand_maxdim will be *added* to current bond dimension + # Obtain expand_maxdim from expansion_factor + expand_maxdim = ceil(Int, expansion_factor * current_dim) + # Enforce max_expand keyword + expand_maxdim = min(max_expand, expand_maxdim) + + # Restrict expand_maxdim below theoretical upper limit + expand_maxdim = min(basis_size-current_dim, expand_maxdim) + # Enforce total maxdim setting (e.g. used in inserter step) + expand_maxdim = min(maxdim-current_dim, expand_maxdim) + # Ensure expand_maxdim is non-negative + expand_maxdim = max(0, expand_maxdim) + return expand_maxdim +end diff --git a/src/solvers/sweep_plans/sweep_plans.jl b/src/solvers/sweep_plans/sweep_plans.jl deleted file mode 100644 index dda6dd96..00000000 --- a/src/solvers/sweep_plans/sweep_plans.jl +++ /dev/null @@ -1,221 +0,0 @@ -using Graphs: AbstractEdge, dst, src -using NamedGraphs.GraphsExtensions: GraphsExtensions - -direction(step_number) = isodd(step_number) ? Base.Forward : Base.Reverse - -function overlap(edge_a::AbstractEdge, edge_b::AbstractEdge) - return intersect(support(edge_a), support(edge_b)) -end - -function support(edge::AbstractEdge) - return [src(edge), dst(edge)] -end - -support(r) = r - -function reverse_region(edges, which_edge; reverse_edge=false, nsites=1, region_kwargs=(;)) - current_edge = edges[which_edge] - if nsites == 1 - !reverse_edge && return [(current_edge, region_kwargs)] - reverse_edge && return [(reverse(current_edge), region_kwargs)] - elseif nsites == 2 - if last(edges) == current_edge - return () - end - future_edges = edges[(which_edge + 1):end] - future_edges = isa(future_edges, AbstractEdge) ? [future_edges] : future_edges - #error if more than single vertex overlap - overlapping_vertex = only(union([overlap(e, current_edge) for e in future_edges]...)) - return [([overlapping_vertex], region_kwargs)] - end -end - -function forward_region(edges, which_edge; nsites=1, region_kwargs=(;)) - if nsites == 1 - current_edge = edges[which_edge] - #handle edge case - if current_edge == last(edges) - overlapping_vertex = only( - union([overlap(e, current_edge) for e in edges[1:(which_edge - 1)]]...) - ) - nonoverlapping_vertex = only( - setdiff([src(current_edge), dst(current_edge)], [overlapping_vertex]) - ) - return [ - ([overlapping_vertex], region_kwargs), ([nonoverlapping_vertex], region_kwargs) - ] - else - future_edges = edges[(which_edge + 1):end] - future_edges = isa(future_edges, AbstractEdge) ? [future_edges] : future_edges - overlapping_vertex = only(union([overlap(e, current_edge) for e in future_edges]...)) - nonoverlapping_vertex = only( - setdiff([src(current_edge), dst(current_edge)], [overlapping_vertex]) - ) - return [([nonoverlapping_vertex], region_kwargs)] - end - elseif nsites == 2 - current_edge = edges[which_edge] - return [([src(current_edge), dst(current_edge)], region_kwargs)] - end -end - -function forward_sweep( - dir::Base.ForwardOrdering, - graph::AbstractGraph; - root_vertex=GraphsExtensions.default_root_vertex(graph), - reverse_edges=false, - region_kwargs, - reverse_kwargs=region_kwargs, - reverse_step=false, - kwargs..., -) - edges = post_order_dfs_edges(graph, root_vertex) - regions = map(eachindex(edges)) do i - forward_region(edges, i; region_kwargs, kwargs...) - end - regions = collect(flatten(regions)) - if reverse_step - reverse_regions = map(eachindex(edges)) do i - reverse_region( - edges, i; reverse_edge=reverse_edges, region_kwargs=reverse_kwargs, kwargs... - ) - end - reverse_regions = collect(flatten(reverse_regions)) - _check_reverse_sweeps(regions, reverse_regions, graph; kwargs...) - regions = interleave(regions, reverse_regions) - end - - return regions -end - -#ToDo: is there a better name for this? unidirectional_sweep? traversal? -function forward_sweep(dir::Base.ReverseOrdering, args...; kwargs...) - return reverse(forward_sweep(Base.Forward, args...; reverse_edges=true, kwargs...)) -end - -function default_sweep_plans( - nsweeps, - init_state; - sweep_plan_func=default_sweep_plan, - root_vertex, - extracter, - extracter_kwargs, - updater, - updater_kwargs, - inserter, - inserter_kwargs, - transform_operator, - transform_operator_kwargs, - kwargs..., -) - extracter, updater, inserter, transform_operator = extend_or_truncate.( - (extracter, updater, inserter, transform_operator), nsweeps - ) - inserter_kwargs, updater_kwargs, extracter_kwargs, transform_operator_kwargs, kwargs = expand.( - ( - inserter_kwargs, - updater_kwargs, - extracter_kwargs, - transform_operator_kwargs, - NamedTuple(kwargs), - ), - nsweeps, - ) - sweep_plans = [] - for i in 1:nsweeps - sweep_plan = sweep_plan_func( - init_state; - root_vertex, - region_kwargs=(; - inserter=inserter[i], - inserter_kwargs=inserter_kwargs[i], - updater=updater[i], - updater_kwargs=updater_kwargs[i], - extracter=extracter[i], - extracter_kwargs=extracter_kwargs[i], - transform_operator=transform_operator[i], - transform_operator_kwargs=transform_operator_kwargs[i], - ), - kwargs[i]..., - ) - push!(sweep_plans, sweep_plan) - end - return sweep_plans -end - -function default_sweep_plan( - graph::AbstractGraph; - root_vertex=GraphsExtensions.default_root_vertex(graph), - region_kwargs, - nsites::Int, -) - return vcat( - [ - forward_sweep( - direction(half), - graph; - root_vertex, - nsites, - region_kwargs=(; internal_kwargs=(; half), region_kwargs...), - ) for half in 1:2 - ]..., - ) -end - -function tdvp_sweep_plan( - graph::AbstractGraph; - root_vertex=GraphsExtensions.default_root_vertex(graph), - region_kwargs, - reverse_step=true, - order::Int, - nsites::Int, - time_step::Number, - t_evolved::Number, -) - sweep_plan = [] - for (substep, fac) in enumerate(sub_time_steps(order)) - sub_time_step = time_step * fac - append!( - sweep_plan, - forward_sweep( - direction(substep), - graph; - root_vertex, - nsites, - region_kwargs=(; - internal_kwargs=(; substep, time_step=sub_time_step, t=t_evolved), - region_kwargs..., - ), - reverse_kwargs=(; - internal_kwargs=(; substep, time_step=(-sub_time_step), t=t_evolved), - region_kwargs..., - ), - reverse_step, - ), - ) - end - return sweep_plan -end - -#ToDo: Move to test. -function _check_reverse_sweeps(forward_sweep, reverse_sweep, graph; nsites, kwargs...) - fw_regions = first.(forward_sweep) - bw_regions = first.(reverse_sweep) - if nsites == 2 - fw_verts = flatten(fw_regions) - bw_verts = flatten(bw_regions) - for v in vertices(graph) - @assert isone(count(isequal(v), fw_verts) - count(isequal(v), bw_verts)) - end - elseif nsites == 1 - fw_verts = flatten(fw_regions) - bw_edges = bw_regions - for v in vertices(graph) - @assert isone(count(isequal(v), fw_verts)) - end - for e in edges(graph) - @assert isone(count(x -> (isequal(x, e) || isequal(x, reverse(e))), bw_edges)) - end - end - return true -end diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl new file mode 100644 index 00000000..f5c5862d --- /dev/null +++ b/src/solvers/sweep_solve.jl @@ -0,0 +1,40 @@ + +default_region_callback(problem; kws...) = nothing + +default_sweep_callback(problem; kws...) = nothing + +function default_sweep_printer(problem; outputlevel, sweep, nsweeps, kws...) + if outputlevel >= 1 + println("Done with sweep $sweep/$nsweeps") + end +end + +function sweep_solve( + sweep_iterator; + outputlevel=0, + region_callback=default_region_callback, + sweep_callback=default_sweep_callback, + sweep_printer=default_sweep_printer, + kwargs..., +) + for (sweep, region_iter) in enumerate(sweep_iterator) + for (region, region_kwargs) in region_tuples(region_iter) + region_callback( + problem(region_iter); + nsweeps=length(sweep_iterator), + outputlevel, + region, + region_kwargs, + sweep, + kwargs..., + ) + end + sweep_callback( + region_iter; nsweeps=length(sweep_iterator), outputlevel, sweep, kwargs... + ) + sweep_printer( + region_iter; nsweeps=length(sweep_iterator), outputlevel, sweep, kwargs... + ) + end + return problem(sweep_iterator) +end diff --git a/src/solvers/truncation_parameters.jl b/src/solvers/truncation_parameters.jl new file mode 100644 index 00000000..bcd7d940 --- /dev/null +++ b/src/solvers/truncation_parameters.jl @@ -0,0 +1,14 @@ +default_maxdim() = typemax(Int) +default_mindim() = 1 +default_cutoff() = 0.0 + +get_or_last(x, i::Integer) = (i >= length(x)) ? last(x) : x[i] + +function truncation_parameters( + sweep; cutoff=default_cutoff(), maxdim=default_maxdim(), mindim=default_mindim() +) + cutoff = get_or_last(cutoff, sweep) + mindim = get_or_last(mindim, sweep) + maxdim = get_or_last(maxdim, sweep) + return (; cutoff, mindim, maxdim) +end diff --git a/test/solvers/test_tree_dmrg.jl b/test/solvers/test_tree_dmrg.jl new file mode 100644 index 00000000..c5c8ab3d --- /dev/null +++ b/test/solvers/test_tree_dmrg.jl @@ -0,0 +1,64 @@ +using Test: @test, @testset +using ITensors +using TensorOperations # Needed to use contraction order finding +using ITensorNetworks: siteinds, ttn, dmrg +import Graphs: dst, edges, src +import ITensorMPS: OpSum + +include("utilities/simple_ed_methods.jl") +include("utilities/tree_graphs.jl") + +@testset "Tree DMRG" begin + outputlevel = 0 + + g = build_tree(; nbranch=3, nbranch_sites=3) + + sites = siteinds("S=1/2", g) + + # Make Heisenberg model Hamiltonian + h = OpSum() + for edge in edges(sites) + i, j = src(edge), dst(edge) + h += "Sz", i, "Sz", j + h += 1/2, "S+", i, "S-", j + h += 1/2, "S-", i, "S+", j + end + H = ttn(h, sites) + + # Make initial product state + state = Dict{Tuple{Int,Int},String}() + for (j, v) in enumerate(gr.vertices(sites)) + state[v] = iseven(j) ? "Up" : "Dn" + end + psi0 = ttn(state, sites) + + (outputlevel >= 1) && println("Computing exact ground state") + Ex, psix = ed_ground_state(H, psi0) + (outputlevel >= 1) && println("Ex = ", Ex) + + cutoff = 1E-5 + maxdim = 40 + nsweeps = 5 + + # + # Test 2-site DMRG without subspace expansion + # + nsites = 2 + trunc = (; cutoff, maxdim) + inserter_kwargs = (; trunc) + E, psi = dmrg(H, psi0; inserter_kwargs, nsites, nsweeps, outputlevel) + (outputlevel >= 1) && println("2-site DMRG energy = ", E) + @test abs(E-Ex) < 1E-5 + + # + # Test 1-site DMRG with subspace expansion + # + nsites = 1 + nsweeps = 5 + trunc = (; cutoff, maxdim) + extracter_kwargs = (; trunc, subspace_algorithm="densitymatrix") + inserter_kwargs = (; trunc) + E, psi = dmrg(H, psi0; extracter_kwargs, inserter_kwargs, nsites, nsweeps, outputlevel) + (outputlevel >= 1) && println("1-site+subspace DMRG energy = ", E) + @test abs(E-Ex) < 1E-5 +end diff --git a/test/solvers/test_tree_tdvp.jl b/test/solvers/test_tree_tdvp.jl new file mode 100644 index 00000000..ee6731fe --- /dev/null +++ b/test/solvers/test_tree_tdvp.jl @@ -0,0 +1,76 @@ +using Test: @test, @testset +using ITensors +using TensorOperations # Needed to use contraction order finding +import ITensorNetworks: dmrg, maxlinkdim, siteinds, time_evolve, ttn +import Graphs: add_vertex!, add_edge!, vertices +import NamedGraphs: NamedGraph +import ITensorMPS: OpSum + +function chain_plus_ancilla(; nchain) + g = NamedGraph() + for j in 1:nchain + add_vertex!(g, j) + end + for j in 1:(nchain - 1) + add_edge!(g, j=>j+1) + end + # Add ancilla vertex near middle of chain + add_vertex!(g, 0) + add_edge!(g, 0=>nchain÷2) + return g +end + +@testset "Tree TDVP on chain plus ancilla" begin + outputlevel = 0 + + N = 10 + g = chain_plus_ancilla(; nchain=N) + + sites = siteinds("S=1/2", g) + + # Make Heisenberg model Hamiltonian + h = OpSum() + for j in 1:(N - 1) + h += "Sz", j, "Sz", j+1 + h += 1/2, "S+", j, "S-", j+1 + h += 1/2, "S-", j, "S+", j+1 + end + H = ttn(h, sites) + + # Make initial product state + state = Dict{Int,String}() + for (j, v) in enumerate(vertices(sites)) + state[v] = iseven(j) ? "Up" : "Dn" + end + psi0 = ttn(state, sites) + + cutoff = 1E-10 + maxdim = 100 + nsweeps = 5 + + nsites = 2 + trunc = (; cutoff, maxdim) + E, gs_psi = dmrg(H, psi0; inserter_kwargs=(; trunc), nsites, nsweeps, outputlevel) + (outputlevel >= 1) && println("2-site DMRG energy = ", E) + + inserter_kwargs=(; trunc) + nsites = 1 + tmax = 0.10 + time_range = 0.0:0.02:tmax + psi1_t = time_evolve(H, gs_psi, time_range; inserter_kwargs, nsites, outputlevel) + (outputlevel >= 1) && println("Done with $nsites-site TDVP") + + @test norm(psi1_t) > 0.999 + + nsites = 2 + psi2_t = time_evolve(H, gs_psi, time_range; inserter_kwargs, nsites, outputlevel) + (outputlevel >= 1) && println("Done with $nsites-site TDVP") + @test norm(psi2_t) > 0.999 + + @test abs(inner(psi1_t, gs_psi)) > 0.99 + @test abs(inner(psi1_t, psi2_t)) > 0.99 + + # Test that accumulated phase angle is E*tmax + z = inner(psi1_t, gs_psi) + @test abs(atan(imag(z)/real(z)) - E*tmax) < 1E-4 +end diff --git a/test/solvers/utilities/simple_ed_methods.jl b/test/solvers/utilities/simple_ed_methods.jl new file mode 100644 index 00000000..c80d3be8 --- /dev/null +++ b/test/solvers/utilities/simple_ed_methods.jl @@ -0,0 +1,33 @@ +import ITensorNetworks: AbstractITensorNetwork + +function ed_ground_state(H, psi0) + ITensors.disable_warn_order() + H = prod(H) + psi = prod(psi0) + expH = exp(H*(-20.0)) + for napply in 1:10 + psi = noprime(expH*psi) + psi /= norm(psi) + end + E = scalar(prime(psi)*H*psi) + return E, psi +end + +function ed_time_evolution( + H::AbstractITensorNetwork, psi::AbstractITensorNetwork, time_points; normalize=false +) + ITensors.disable_warn_order() + H = prod(H) + psi = prod(psi) + exponents = [-im*t for t in time_points] + steps = diff([0.0, exponents...])[2:end] + H_map = ψ -> noprime(H*psi) + for step in steps + expH = exp(H * step) + psi = noprime(expH * psi) + if normalize + psi /= norm(psi) + end + end + return psi +end diff --git a/test/solvers/utilities/tree_graphs.jl b/test/solvers/utilities/tree_graphs.jl new file mode 100644 index 00000000..4fdff761 --- /dev/null +++ b/test/solvers/utilities/tree_graphs.jl @@ -0,0 +1,23 @@ +import Graphs as gr +import NamedGraphs as ng + +""" + build_tree + + Make a tree with central vertex (0,0) and + nbranch branches of nbranch_sites each. +""" +function build_tree(; nbranch=3, nbranch_sites=3) + g = ng.NamedGraph() + gr.add_vertex!(g, (0, 0)) + for branch in 1:nbranch, site in 1:nbranch_sites + gr.add_vertex!(g, (branch, site)) + end + for branch in 1:nbranch + gr.add_edge!(g, (0, 0)=>(branch, 1)) + for site in 2:nbranch_sites + gr.add_edge!(g, (branch, site-1)=>(branch, site)) + end + end + return g +end