Skip to content

Commit

Permalink
WP: ProjTensorTrain support abitrary numbers of site indices
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed May 26, 2024
1 parent 4416097 commit 2466c16
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
28 changes: 12 additions & 16 deletions src/crossinterpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@ function adaptiveinterpolate(
creator::AbstractPatchCreator{T,M},
pordering::PatchOrdering;
sleep_time::Float64=1e-6,
maxnleaves=100,
maxnleaves=typemax(Int),
verbosity=0,
)::Dict{Projector,M} where {T,M}
leaves = Dict{AbstractPatchCreator, Union{Task,PatchCreatorResult{T,M}}}()
sitedims = [[x] for x in creator.localdims]
leaves = Dict{AbstractPatchCreator,Union{Task,PatchCreatorResult{T,M}}}()

# Add root
leaves[creator] = createpatch(creator)
Expand Down Expand Up @@ -175,9 +174,7 @@ function adaptiveinterpolate(
if verbosity > 0
println("Creating a task for $(prefix_) ...")
end
t = @task fetch(
@spawnat :any createpatch(pcreator_child)
)
t = @task fetch(@spawnat :any createpatch(pcreator_child))
newtasks[pcreator_child] = t
schedule(t)
end
Expand All @@ -196,7 +193,8 @@ function adaptiveinterpolate(

leaves_done = Dict{Projector,M}()
for (k, v) in leaves
leaves_done[k.f.projector] = isnothing(v.data) ? _zerott(T, k, pordering, creator.localdims) : v.data
leaves_done[k.f.projector] =
isnothing(v.data) ? _zerott(T, k, pordering, creator.localdims) : v.data
end

return leaves_done
Expand Down Expand Up @@ -249,7 +247,7 @@ function TCI2PatchCreator(
::Type{T},
f,
localdims::Vector{Int},
projector::Union{Projector,Nothing} = nothing;
projector::Union{Projector,Nothing}=nothing;
rtol::Float64=1e-8,
maxbonddim::Int=100,
verbosity::Int=0,
Expand Down Expand Up @@ -328,15 +326,17 @@ function _crossinterpolate2(
)
end

function project(obj::TCI2PatchCreator{T}, projector::Projector)::TCI2PatchCreator{T} where {T}
function project(
obj::TCI2PatchCreator{T}, projector::Projector
)::TCI2PatchCreator{T} where {T}
projector <= obj.projector || error(
"Projector $projector is not a subset of the original projector $(obj.f.projector)"
"Projector $projector is not a subset of the original projector $(obj.f.projector)",
)

obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
obj_copy.projector = deepcopy(projector)
#if !(obj_copy.f.projector <= projector)
obj_copy.f = project(obj_copy.f, projector)
obj_copy.f = project(obj_copy.f, projector)
#end
return obj_copy
end
Expand Down Expand Up @@ -389,18 +389,14 @@ function _estimate_maxval(f, localdims; ntry=100)
return maxval, pivot
end


function makeproj(
po::PatchOrdering, prefix::Vector{Int}, localdims::Vector{Int}
)
function makeproj(po::PatchOrdering, prefix::Vector{Int}, localdims::Vector{Int})
data = [[0] for _ in localdims]
for (i, n) in enumerate(po.ordering[1:length(prefix)])
data[n][1] = prefix[i]
end
return Projector(data, [[x] for x in localdims])
end


function makeproj(
po::PatchOrdering, prefix::Vector{Vector{Int}}, sitedims::Vector{Vector{Int}}
)
Expand Down
6 changes: 6 additions & 0 deletions src/projectable_evaluator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ function sum(obj::ProjectableEvaluator{T})::T where {T}
return zero(T)
end

"""
Project the object on the overlap of `prj` and `obj.projector`.
The requirement for the implementation is that
the projector of the returned object is a subset of `prj`.
"""
function project(
obj::ProjectableEvaluator{T}, prj::Projector
)::ProjectableEvaluator{T} where {T}
Expand Down
3 changes: 3 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ function allequal(collection)
return all(collection .== c)
end

function Not(index::Int, length::Int)
return vcat(1:index-1, index+1:length)
end

function _multii(sitedims::Vector{Int}, i::Int)::Vector{Int}
i <= prod(sitedims) || error("Index out of range $i, $sitedims")
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ include("util_tests.jl")
include("projector_tests.jl")
include("projectable_evaluator_tests.jl")
include("projtensortrain_tests.jl")
include("mul_tests.jl")
#include("mul_tests.jl")
6 changes: 6 additions & 0 deletions test/util_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import TCIAlgorithms as TCIA
end
end

@testset "Not" begin
A = [1, 2, 3]
A[collect(TCIA.Not(1, 3))] == A[2:3]
A[collect(TCIA.Not(2, 3))] == [A[1], A[3]]
end

@testset "findinitialpivots" begin
R = 8
localdims = fill(2, R)
Expand Down

0 comments on commit 2466c16

Please sign in to comment.