Skip to content

Commit

Permalink
Merge pull request #41 from tensor4all/recycle-pivots
Browse files Browse the repository at this point in the history
Implement pivots recycling and bug fixes
  • Loading branch information
gianlucagrosso authored Feb 7, 2025
2 parents 104686a + 34f57c7 commit 6db8235
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 39 deletions.
81 changes: 62 additions & 19 deletions src/crossinterpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
checkbatchevaluatable::Bool
loginterval::Int
initialpivots::Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
recyclepivots::Bool
end

function Base.show(io::IO, obj::TCI2PatchCreator{T}) where {T}
Expand All @@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
obj.checkbatchevaluatable,
obj.loginterval,
obj.initialpivots,
obj.recyclepivots,
)
end

Expand All @@ -157,7 +159,8 @@ function TCI2PatchCreator(
ninitialpivot=5,
checkbatchevaluatable=false,
loginterval=10,
initialpivots=Vector{MultiIndex}[],
initialpivots=MultiIndex[],
recyclepivots=false,
)::TCI2PatchCreator{T} where {T}
#t1 = time_ns()
if projector === nothing
Expand All @@ -183,6 +186,7 @@ function TCI2PatchCreator(
checkbatchevaluatable,
loginterval,
initialpivots,
recyclepivots,
)
end

Expand All @@ -206,6 +210,7 @@ function _crossinterpolate2!(
verbosity::Int=0,
checkbatchevaluatable=false,
loginterval=10,
recyclepivots=false,
) where {T}
ncheckhistory = 3
ranks, errors = TCI.optimize!(
Expand All @@ -231,13 +236,45 @@ function _crossinterpolate2!(
ncheckhistory_ = min(ncheckhistory, length(errors))
maxbonddim_hist = maximum(ranks[(end - ncheckhistory_ + 1):end])

return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci), TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim
)
if recyclepivots
return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci),
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
_globalpivots(tci),
)

else
return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci),
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
)
end
end

# Generating global pivots from local ones
function _globalpivots(
tci::TCI.TensorCI2{T}; onlydiagonal=true
)::Vector{MultiIndex} where {T}
Isets = tci.Iset
Jsets = tci.Jset
L = length(Isets)
p = Set{MultiIndex}()
# Pivot matrices
for bondindex in 1:(L - 1)
if onlydiagonal
for (x, y) in zip(Isets[bondindex + 1], Jsets[bondindex])
push!(p, vcat(x, y))
end
else
for x in Isets[bondindex + 1], y in Jsets[bondindex]
push!(p, vcat(x, y))
end
end
end
return collect(p)
end

function createpatch(obj::TCI2PatchCreator{T}) where {T}
proj = obj.projector
fsubset = _FuncAdapterTCI2Subset(obj.f)

tci = if isapproxttavailable(obj.f)
Expand All @@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
end
tci
else
# Random initial pivots
initialpivots = MultiIndex[]
let
mask = [!isprojectedat(proj, n) for n in 1:length(proj)]
for idx in obj.initialpivots
idx_ = [[i] for i in idx]
if idx_ <= proj
push!(initialpivots, idx[mask])
end
if obj.recyclepivots
# First patching iteration: random pivots
if length(fsubset.localdims) == length(obj.localdims)
initialpivots = union(
obj.initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)
# Next iterations: recycle previously generated pivots
else
initialpivots = copy(obj.initialpivots)
end
else
initialpivots = union(
obj.initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)
end
append!(
initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)

if all(fsubset.(initialpivots) .== 0)
return PatchCreatorResult{T,TensorTrainState{T}}(nothing, true)
end
Expand All @@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
verbosity=obj.verbosity,
checkbatchevaluatable=obj.checkbatchevaluatable,
loginterval=obj.loginterval,
recyclepivots=obj.recyclepivots,
)
end

Expand All @@ -301,9 +343,9 @@ function adaptiveinterpolate(
verbosity=0,
maxbonddim=typemax(Int),
tolerance=1e-8,
initialpivots=Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
initialpivots=MultiIndex[], # Make it to Vector{MMultiIndex}?
recyclepivots=false,
)::ProjTTContainer{T} where {T}
t1 = time_ns()
creator = TCI2PatchCreator(
T,
f,
Expand All @@ -313,6 +355,7 @@ function adaptiveinterpolate(
verbosity,
ntry=10,
initialpivots=initialpivots,
recyclepivots=recyclepivots,
)
tmp = adaptiveinterpolate(creator, pordering; verbosity)
return reshape(tmp, f.sitedims)
Expand Down
37 changes: 32 additions & 5 deletions src/patching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ abstract type AbstractPatchCreator{T,M} end
mutable struct PatchCreatorResult{T,M}
data::Union{M,Nothing}
isconverged::Bool
resultpivots::Vector{MultiIndex}

function PatchCreatorResult{T,M}(
data::Union{M,Nothing}, isconverged::Bool, resultpivots::Vector{MultiIndex}
)::PatchCreatorResult{T,M} where {T,M}
return new{T,M}(data, isconverged, resultpivots)
end

function PatchCreatorResult{T,M}(
data::Union{M,Nothing}, isconverged::Bool
)::PatchCreatorResult{T,M} where {T,M}
return new{T,M}(data, isconverged, MultiIndex[])
end
end

function _reconst_prefix(projector::Projector, pordering::PatchOrdering)
Expand All @@ -63,10 +76,19 @@ function __taskfunc(creator::AbstractPatchCreator{T,M}, pordering; verbosity=0)
for ic in 1:creator.localdims[pordering.ordering[length(prefix) + 1]]
prefix_ = vcat(prefix, ic)
projector_ = makeproj(pordering, prefix_, creator.localdims)
#if verbosity > 0
##println("Creating a task for $(prefix_) ...")
#end
push!(newtasks, project(creator, projector_))

# Pivots are shorter, pordering index is in a different position
active_dims_ = findall(x -> x == [0], creator.projector.data)
pos_ = findfirst(x -> x == pordering.ordering[length(prefix) + 1], active_dims_)
pivots_ = [
copy(piv) for piv in filter(piv -> piv[pos_] == ic, patch.resultpivots)
]

if !isempty(pivots_)
deleteat!.(pivots_, pos_)
end

push!(newtasks, project(creator, projector_; pivots=pivots_))
end
return nothing, newtasks
end
Expand All @@ -77,14 +99,19 @@ function _zerott(T, prefix, po::PatchOrdering, localdims::Vector{Int})
return TensorTrain([zeros(T, 1, d, 1) for d in localdims_])
end

function project(obj::AbstractPatchCreator{T,M}, projector::Projector) where {T,M}
function project(
obj::AbstractPatchCreator{T,M},
projector::Projector;
pivots::Vector{MultiIndex}=MultiIndex[],
) where {T,M}
projector <= obj.projector || error(
"Projector $projector is not a subset of the original projector $(obj.f.projector)",
)

obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
obj_copy.projector = deepcopy(projector)
obj_copy.f = project(obj_copy.f, projector)
obj_copy.initialpivots = deepcopy(pivots)
return obj_copy
end

Expand Down
29 changes: 16 additions & 13 deletions src/projectable_evaluator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,28 +248,31 @@ function batchevaluateprj(
# Some of indices might be projected
NL = length(leftmmultiidxset[1])
NR = length(rightmmultiidxset[1])
L = length(obj)

NL + NR + M == length(obj) ||
error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(length(obj))")
NL + NR + M == L || error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(L)")

L = length(obj)
returnshape = projectedshape(obj.projector, NL + 1, L - NR)
result::Array{T,M + 2} = zeros(
T,
length(leftmmultiidxset),
prod.(obj.sitedims[(1 + NL):(L - NR)])...,
length(rightmmultiidxset),
T, length(leftmmultiidxset), returnshape..., length(rightmmultiidxset)
)
result[lmask, .., rmask] .= begin

projmask = map(
p -> p == 0 ? Colon() : p,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
)
slice = map(
p -> p == 0 ? Colon() : 1,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
)

result[lmask, slice..., rmask] .= begin
result_lrmask_multii = reshape(
result_lrmask,
size(result_lrmask)[1],
collect(Iterators.flatten(obj.sitedims[(1 + NL):(L - NR)]))...,
size(result_lrmask)[end],
)
projmask = map(
p -> p == 0 ? Colon() : p,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(length(obj) - NR)),
)
) # Gianluca - this step might be not needed. I leave it for safety
result_lrmask_multii[:, projmask..., :]
end
return result
Expand Down
Loading

0 comments on commit 6db8235

Please sign in to comment.