Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed May 28, 2024
1 parent 23befd6 commit 5d0d23b
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 56 deletions.
1 change: 0 additions & 1 deletion src/TCIAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ include("projtensortrain.jl")
include("mul.jl")
include("crossinterpolate.jl")


#include("util.jl")
#include("tensor.jl")
#include("adapter.jl")
Expand Down
12 changes: 8 additions & 4 deletions src/bak/dtensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ mutable struct DTensorTrain{T} <: ProjectableEvaluator{T}
end

function DTensorTrain(sitetensors::AbstractVector{Array{T,N}}) where {T,N}
sitedims = [collect(size(t)[2:end-1]) for t in sitetensors]
DTensorTrain{T}(sitetensors, sitedims,
sitetensors_fused = [reshape(x, size(x, 1), :, size(x)[end]) for x in obj.sitetensors]
sitedims = [collect(size(t)[2:(end - 1)]) for t in sitetensors]
return DTensorTrain{T}(
sitetensors,
sitedims;
sitetensors_fused=[
reshape(x, size(x, 1), :, size(x)[end]) for x in obj.sitetensors
],
)
end

Expand All @@ -32,4 +36,4 @@ function (obj::DTensorTrain{T})(
sitetensors_fused = [reshape(x, size(x, 1), :, size(x)[end]) for x in obj.sitetensors]

return TensorTrain(sitetensors_fused)(leftindexset_, rightindexset_, Val(M))
end
end
20 changes: 10 additions & 10 deletions src/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ struct LazyMatrixMul{T} <: ProjectableEvaluator{T}
b::ProjTensorTrain{T}
end

function LazyMatrixMul{T}(a::ProjTensorTrain{T}, b::ProjTensorTrain{T}; coeff=one(T)) where {T}
function LazyMatrixMul{T}(
a::ProjTensorTrain{T}, b::ProjTensorTrain{T}; coeff=one(T)
) where {T}
# This restriction is due to simulicity and to be removed.
all(length.(a.sitedims) .== 2) || error("The number of site indices must be 2")
all(length.(b.sitedims) .== 2) || error("The number of site indices must be 2")
Expand All @@ -30,17 +32,13 @@ function LazyMatrixMul(a::ProjTensorTrain{T}, b::ProjTensorTrain{T}; coeff=one(1
end

function project(
obj::LazyMatrixMul{T},
prj::Projector;
kwargs...
obj::LazyMatrixMul{T}, prj::Projector; kwargs...
)::LazyMatrixMul{T} where {T}
projector_a_new = Projector(
[[x[1], y[2]] for (x, y) in zip(prj, obj.a.projector.sitedims)],
obj.a.sitedims
[[x[1], y[2]] for (x, y) in zip(prj, obj.a.projector.sitedims)], obj.a.sitedims
)
projector_b_new = Projector(
[[x[1], y[2]] for (x, y) in zip(obj.b.projector.sitedims, prj)],
obj.b.sitedims
[[x[1], y[2]] for (x, y) in zip(obj.b.projector.sitedims, prj)], obj.b.sitedims
)
obj.a = project(obj.a, projector_a_new; kwargs...)
obj.b = project(obj.b, projector_b_new; kwargs...)
Expand Down Expand Up @@ -86,6 +84,8 @@ function batchevaluateprj(
isprojectedat(obj.projector, n) ? 1 : prod(obj.sitedims[n]) for
n in (NL + 1):(L - NR)
]
res = TCI.batchevaluate(obj.contraction, leftindexset_, rightindexset_, Val(M), projector)
res = TCI.batchevaluate(
obj.contraction, leftindexset_, rightindexset_, Val(M), projector
)
return reshape(res, length(leftindexset), returnshape..., length(rightindexset))
end
end
33 changes: 19 additions & 14 deletions src/projectable_evaluator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ function (obj::ProjectableEvaluator{T})(
NR = length(rightindexset[1])
L = length(obj)

results = zeros(T, length(leftindexset), prod.(obj.sitedims[(NL + 1):(L - NR)])..., length(rightindexset))
results = zeros(
T,
length(leftindexset),
prod.(obj.sitedims[(NL + 1):(L - NR)])...,
length(rightindexset),
)

slice = (
isprojectedat(obj.projector, n) ? _lineari(sitedims, projector[n]) : Colon() for
Expand Down Expand Up @@ -154,7 +159,7 @@ struct ProjectableEvaluatorAdapter{T} <: ProjectableEvaluator{T}
function ProjectableEvaluatorAdapter{T}(
f::TCI.BatchEvaluator{T}, sitedims::Vector{Vector{Int}}, projector::Projector
) where {T}
new{T}(f, sitedims, projector)
return new{T}(f, sitedims, projector)
end
function ProjectableEvaluatorAdapter{T}(
f::TCI.BatchEvaluator{T}, sitedims::Vector{Vector{Int}}
Expand All @@ -165,18 +170,14 @@ end

Base.length(obj::ProjectableEvaluatorAdapter) = length(obj.sitedims)

function makeprojectable(
::Type{T}, f::Function, localdims::Vector{Int}
) where {T}
function makeprojectable(::Type{T}, f::Function, localdims::Vector{Int}) where {T}
return ProjectableEvaluatorAdapter{T}(
f isa TCI.BatchEvaluator ? f : TCI.makebatchevaluatable(T, f, localdims),
[[x] for x in localdims]
[[x] for x in localdims],
)
end

function (obj::ProjectableEvaluatorAdapter{T})(
indexset::MMultiIndex
)::T where {T}
function (obj::ProjectableEvaluatorAdapter{T})(indexset::MMultiIndex)::T where {T}
return indexset <= obj.projector ? obj.f(lineari(obj.sitedims, indexset)) : zero(T)
end

Expand All @@ -200,13 +201,18 @@ function batchevaluateprj(
NL = length(leftindexset[1])
NR = length(rightindexset[1])
projmask = [
isprojectedat(obj.projector, n) ? obj.projector[n] : Colon()
for n in 1+NL:length(obj)-NR
isprojectedat(obj.projector, n) ? obj.projector[n] : Colon() for
n in (1 + NL):(length(obj) - NR)
]

tmp = result_within_proj[:, projmask..., :]
L = length(obj)
result = zeros(T, length(leftindexset), prod.(obj.sitedims[1+NL:L-NR])..., length(rightindexset))
result = zeros(
T,
length(leftindexset),
prod.(obj.sitedims[(1 + NL):(L - NR)])...,
length(rightindexset),
)
result[lmask, .., rmask] .= tmp
return result
end
Expand All @@ -217,10 +223,9 @@ function project(
return ProjectableEvaluatorAdapter{T}(obj.f, obj.sitedims, prj)
end


function fulltensor(obj::ProjectableEvaluator)
localdims = collect(prod.(obj.sitedims))
r = [obj(collect(Tuple(i))) for i in CartesianIndices(Tuple(localdims))]
returnsize = collect(Iterators.flatten(obj.sitedims))
return reshape(r, returnsize...)
end
end
17 changes: 12 additions & 5 deletions src/projector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ end

# Extract ilegg-th index from the projector
#function only(p::Projector, ilegg::Int)::Projector
#data = [[p.data[l][ilegg]] for l in 1:length(p)]
#sitedims = [[p.sitedims[l][ilegg]] for l in 1:length(p)]
#return Projector(data, sitedims)
#data = [[p.data[l][ilegg]] for l in 1:length(p)]
#sitedims = [[p.sitedims[l][ilegg]] for l in 1:length(p)]
#return Projector(data, sitedims)
#end

Base.:(==)(a::Projector, b::Projector)::Bool = (a.data == b.data)
Expand Down Expand Up @@ -138,8 +138,6 @@ function Base.reshape(
return Projector(newprojectordata, dims)
end



function isprojectedat(p::Projector, n::Int)::Bool
if all(p.data[n] .== 0)
return false
Expand Down Expand Up @@ -174,3 +172,12 @@ fullindices(projector, indexset::MultiIndex)::MultiIndex = lineari(
projector.sitedims, fullindices(projector, multii(projector.sitedims, indexset))
)

function projectedshape(projector::Projector, startidx::Int, lastidx::Int)::Vector{Int}
res = Int[
prod(
projector[n][s] > 0 ? 1 : projector.sitedims[n][s] for
s in eachindex(projector.data[n])
) for n in startidx:lastidx
]
return res
end
2 changes: 1 addition & 1 deletion src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function allequal(collection)
end

function Not(index::Int, length::Int)
return vcat(1:index-1, index+1:length)
return vcat(1:(index - 1), (index + 1):length)
end

function _multii(sitedims::Vector{Int}, i::Int)::Vector{Int}
Expand Down
21 changes: 10 additions & 11 deletions test/bak/dtensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,21 @@ import TCIAlgorithms as TCIA
tt = TCIA.DTensorTrain([rand(2, 3, 5, 4), rand(4, 3, 2, 4), rand(4, 2, 5, 4)])
@test tt.sitedims == [[3, 5], [3, 2], [2, 5]]

tt_ref = TCI.TensorTrain([reshape(x, size(x, 1), :, size(x)[end]) for x in tt.sitetensors])
tt_ref = TCI.TensorTrain([
reshape(x, size(x, 1), :, size(x)[end]) for x in tt.sitetensors
])

leftindexset = [
[[1, 1]],
[[1, 2]]
]
rightindexset = [
[[1, 1]],
[[1, 2]]
]
leftindexset = [[[1, 1]], [[1, 2]]]
rightindexset = [[[1, 1]], [[1, 2]]]

NL = 1
NR = 1
leftindexset_ = [TCIA.lineari(tt.sitedims[1:NL], x) for x in leftindexset]
rightindexset_ = [TCIA.lineari(tt.sitedims[(end - NR + 1):end], x) for x in rightindexset]
rightindexset_ = [
TCIA.lineari(tt.sitedims[(end - NR + 1):end], x) for x in rightindexset
]

@test tt(leftindexset, rightindexset, Val(1)) tt_ref(leftindexset_, rightindexset_, Val(1))
@test tt(leftindexset, rightindexset, Val(1))
tt_ref(leftindexset_, rightindexset_, Val(1))
end
end
4 changes: 2 additions & 2 deletions test/crossinterpolate_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import TCIAlgorithms as TCIA

p = TCIA.Projector([[0, 0], [0, 0], [2, 2], [0, 0], [0, 0], [1, 1]], sitedims)

ptt = TCIA.ProjTensorTrain(tt, p)
ptt = TCIA.project(TCIA.ProjTensorTrain(tt), p)

ptt_wrapper = TCIA._FuncAdapterTCI2Subset(ptt)
@test ptt_wrapper.localdims == [4, 4, 4, 4]
Expand Down Expand Up @@ -51,4 +51,4 @@ import TCIAlgorithms as TCIA
p = TCIA.Projector([[1, 1], [0, 0], [1, 1], [0, 0], [1, 1]], sitedims)
@test TCIA.fulllength_rightindexset(p, [Int[]]) == [[1]]
end
end
end
2 changes: 1 addition & 1 deletion test/mul_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ import TCIAlgorithms: Projector, project, ProjTensorTrain, LazyMatrixMul
ab_ref = TCI.contract_naive(a_tt, b_tt)

@test TCIA.fulltensor(ab) TCIA.fulltensor(ProjTensorTrain(ab_ref))
end
end
5 changes: 2 additions & 3 deletions test/projectable_evaluator_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import TCIAlgorithms as TCIA

import TCIAlgorithms: Projector

import QuanticsGrids:
DiscretizedGrid, quantics_to_origcoord, origcoord_to_quantics
import QuanticsGrids: DiscretizedGrid, quantics_to_origcoord, origcoord_to_quantics
import QuanticsGrids as QG

@testset "makeprojectable" begin
Expand All @@ -29,4 +28,4 @@ import QuanticsGrids as QG

@test vec(pqf(leftindexset_, rightindexset_, Val(R - 3)))
vec([qf([l..., i, r...]) for l in leftindexset, i in 1:4, r in rightindexset])
end
end
9 changes: 5 additions & 4 deletions test/projector_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ import TCIAlgorithms: Projector
Projector([[0, 0], [2, 1, 1]], sitedimsnew)
end

@testset "isprojectedat" begin
@testset "projectedshape" begin
sitedims = [[2, 2], [2, 2], [2, 2]]

p = TCIA.Projector([[0, 0], [1, 1], [0, 0]], sitedims)
@test [TCIA.isprojectedat(p, n) for n in 1:length(p)] == [false, true, false]
@test TCIA.projectedshape(p, 1, 3) == [4, 1, 4]

p = TCIA.Projector([[0, 1]], sitedims)
@test_throws ErrorException TCIA.isprojectedat(p, 1)
p = TCIA.Projector([[0, 0], [1, 0], [0, 0]], sitedims)
@test TCIA.projectedshape(p, 1, 3) == [4, 2, 4]
end

@testset "fullindices" begin
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ include("projector_tests.jl")
include("projectable_evaluator_tests.jl")
include("projtensortrain_tests.jl")
include("mul_tests.jl")
include("crossinterpolate_tests.jl")

0 comments on commit 5d0d23b

Please sign in to comment.