Skip to content

Commit

Permalink
WP: adaptiveinterpolate works
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed May 25, 2024
1 parent 563c2f0 commit 4416097
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 27 deletions.
107 changes: 81 additions & 26 deletions src/crossinterpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ end

Base.length(po::PatchOrdering) = length(po.ordering)


"""
n is the length of the prefix.
"""
Expand Down Expand Up @@ -73,14 +72,12 @@ function (obj::_FuncAdapterTCI2Subset{T})(
NR = length(rightindexset_fulllen[1])
M_ = length(obj.f) - NL - NR
projected = [
isprojectedat(obj.f.projector, n) ? 1 : Colon() for
n in (NL + 1):(orgL - NR)
isprojectedat(obj.f.projector, n) ? 1 : Colon() for n in (NL + 1):(orgL - NR)
]
res = batchevaluateprj(obj.f, leftindexset_fulllen, rightindexset_fulllen, Val(M_))
return res[:, projected..., :]
end


"""
leftindexset: Vector of indices on unprojected indices
Returns: Vector of indices on projected and unprojected indices
Expand Down Expand Up @@ -121,26 +118,33 @@ function fulllength_rightindexset(
return collect(reverse.(r))
end

function _reconst_prefix(projector::Projector, pordering::PatchOrdering)
np = Base.sum((isprojectedat(projector, n) for n in 1:length(projector)))
return [Base.only(projector[n]) for n in pordering.ordering[1:np]]
end

function adaptiveinterpolate(
creator::AbstractPatchCreator{T,M},
pordering::PatchOrdering;
sleep_time::Float64=1e-6,
maxnleaves=100,
verbosity=0,
)::Dict{Vector{MultiIndex},M} where {T,M}
leaves = Dict{Vector{Int},Union{Task,PatchCreatorResult{T,M}}}()
)::Dict{Projector,M} where {T,M}
leaves = Dict{AbstractPatchCreator, Union{Task,PatchCreatorResult{T,M}}}()
sitedims = [[x] for x in creator.localdims]

# Add root
leaves[[]] = createpatch(creator, pordering, Vector{Int}[])
leaves[creator] = createpatch(creator)

while true
sleep(sleep_time) # Not to run the loop too quickly

done = true
newtasks = Dict{Vector{Int},Task}()
for (prefix, leaf) in leaves
newtasks = Dict{AbstractPatchCreator{T,M},Task}()
for (pcreator, leaf) in leaves
# If the task is done, fetch the result, which
# will be analyzed in the next loop.
prefix::Vector{Int} = _reconst_prefix(pcreator.projector, pordering)
if leaf isa Task
if istaskdone(leaf)
if verbosity > 0
Expand All @@ -151,7 +155,7 @@ function adaptiveinterpolate(
err_msg = sprint(showerror, fetched.captured)
error("Error in creating a patch for $(prefix): $err_msg")
end
leaves[prefix] = fetched
leaves[pcreator] = fetched
end
done = false
continue
Expand All @@ -161,19 +165,20 @@ function adaptiveinterpolate(

if !leaf.isconverged && length(leaves) < maxnleaves
done = false
delete!(leaves, prefix)
delete!(leaves, pcreator)

for ic in 1:creator.localdims[pordering.ordering[length(prefix) + 1]]
prefix_ = vcat(prefix, ic)
projector_ = makeproj(pordering, prefix_, pcreator.localdims)
pcreator_child = project(pcreator, projector_)

if verbosity > 0
println("Creating a task for $(prefix_) ...")
end
t = @task fetch(
@spawnat :any createpatch(
creator, pordering, [[x] for x in prefix_]
)
@spawnat :any createpatch(pcreator_child)
)
newtasks[prefix_] = t
newtasks[pcreator_child] = t
schedule(t)
end
end
Expand All @@ -189,9 +194,9 @@ function adaptiveinterpolate(
end
end

leaves_done = Dict{Vector{Vector{Int}},M}()
leaves_done = Dict{Projector,M}()
for (k, v) in leaves
leaves_done[[[x] for x in k]] = isnothing(v.data) ? _zerott(T, k, pordering, creator.localdims) : v.data
leaves_done[k.f.projector] = isnothing(v.data) ? _zerott(T, k, pordering, creator.localdims) : v.data
end

return leaves_done
Expand All @@ -209,6 +214,7 @@ end
mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}}
f::ProjectableEvaluator{T}
localdims::Vector{Int}
projector::Projector
rtol::Float64
maxbonddim::Int
verbosity::Int
Expand All @@ -221,10 +227,29 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
initialpivots::Vector{MultiIndex}
end

function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} where {T}
return TCI2PatchCreator{T}(
obj.f,
obj.localdims,
obj.projector,
obj.rtol,
obj.maxbonddim,
obj.verbosity,
obj.tcikwargs,
obj.maxval,
obj.atol,
obj.ninitialpivot,
obj.checkbatchevaluatable,
obj.loginterval,
obj.initialpivots,
)
end

function TCI2PatchCreator(
::Type{T},
f,
localdims::Vector{Int};
localdims::Vector{Int},
projector::Union{Projector,Nothing} = nothing;
rtol::Float64=1e-8,
maxbonddim::Int=100,
verbosity::Int=0,
Expand All @@ -236,9 +261,18 @@ function TCI2PatchCreator(
initialpivots=Vector{MultiIndex}[],
)::TCI2PatchCreator{T} where {T}
maxval, _ = _estimate_maxval(f, localdims; ntry=ntry)
if projector === nothing
projector = Projector([[0] for _ in localdims], [[x] for x in localdims])
end

if !(f.projector <= projector)
f = project(f, projector)
end

return TCI2PatchCreator{T}(
f,
localdims,
projector,
rtol,
maxbonddim,
verbosity,
Expand Down Expand Up @@ -294,17 +328,26 @@ function _crossinterpolate2(
)
end

function createpatch(
obj::TCI2PatchCreator{T}, pordering::PatchOrdering, prefix::Vector{Vector{Int}}
) where {T}
proj = makeproj(pordering, prefix, sitedims)
function project(obj::TCI2PatchCreator{T}, projector::Projector)::TCI2PatchCreator{T} where {T}
projector <= obj.projector || error(
"Projector $projector is not a subset of the original projector $(obj.f.projector)"
)

obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
obj_copy.projector = deepcopy(projector)
#if !(obj_copy.f.projector <= projector)
obj_copy.f = project(obj_copy.f, projector)
#end
return obj_copy
end

function createpatch(obj::TCI2PatchCreator{T}) where {T}
proj = obj.projector
fsubset = _FuncAdapterTCI2Subset(obj.f)

initialpivots = MultiIndex[]
let
mask = maskactiveindices(pordering, length(prefix))
sitedims = [[Base.only(d)] for d in obj.localdims]
proj = makeproj(pordering, prefix, sitedims)
mask = [!isprojectedat(proj, n) for n in 1:length(proj)]
for idx in obj.initialpivots
idx_ = [[i] for i in idx]
if idx_ <= proj
Expand All @@ -321,7 +364,7 @@ function createpatch(
return _crossinterpolate2(
T,
fsubset,
fprj.localdims,
fsubset.localdims,
initialpivots,
obj.atol;
maxbonddim=obj.maxbonddim,
Expand All @@ -346,6 +389,18 @@ function _estimate_maxval(f, localdims; ntry=100)
return maxval, pivot
end


function makeproj(
po::PatchOrdering, prefix::Vector{Int}, localdims::Vector{Int}
)
data = [[0] for _ in localdims]
for (i, n) in enumerate(po.ordering[1:length(prefix)])
data[n][1] = prefix[i]
end
return Projector(data, [[x] for x in localdims])
end


function makeproj(
po::PatchOrdering, prefix::Vector{Vector{Int}}, sitedims::Vector{Vector{Int}}
)
Expand Down
7 changes: 6 additions & 1 deletion src/projectable_evaluator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ struct ProjectableEvaluatorAdapter{T} <: ProjectableEvaluator{T}
f::TCI.BatchEvaluator{T}
sitedims::Vector{Vector{Int}}
projector::Projector
function ProjectableEvaluatorAdapter{T}(
f::TCI.BatchEvaluator{T}, sitedims::Vector{Vector{Int}}, projector::Projector
) where {T}
new{T}(f, sitedims, projector)
end
function ProjectableEvaluatorAdapter{T}(
f::TCI.BatchEvaluator{T}, sitedims::Vector{Vector{Int}}
) where {T}
Expand Down Expand Up @@ -198,5 +203,5 @@ end
function project(
obj::ProjectableEvaluatorAdapter{T}, prj::Projector
)::ProjectableEvaluator{T} where {T}
return ProjectableEvaluatorAdapter(obj.f, obj.sitedims, prj)
return ProjectableEvaluatorAdapter{T}(obj.f, obj.sitedims, prj)
end
6 changes: 6 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,9 @@ function _contract(
amat * bmat, _getindex(size(a), rest_idx_a)..., _getindex(size(b), rest_idx_b)...
)
end

function shallowcopy(original)
fieldnames = Base.fieldnames(typeof(original))
new_fields = [Base.copy(getfield(original, name)) for name in fieldnames]
return (typeof(original))(new_fields...)
end

0 comments on commit 4416097

Please sign in to comment.