Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TensorNetwork internals to incidence matrix representation #120

Merged
merged 13 commits into from
Nov 10, 2023
Merged
34 changes: 12 additions & 22 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,41 +28,31 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pull
@non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)

function ChainRulesCore.ProjectTo(tn::T) where {T<:AbstractTensorNetwork}
    # Only the tensor contents carry differentiable information; the index
    # metadata is projected to `NoTangent()` downstream, so a single projector
    # over the tensor collection is all that must be stored.
    # (The pre-refactor field-reflection loop over `:tensors` was superseded
    # by this one-liner and is removed here.)
    return ProjectTo{T}(; tensors = ProjectTo(tensors(tn)))
end

function (projector::ProjectTo{T})(dx::T) where {T<:AbstractTensorNetwork}
    # Primal-valued `dx`: project each tensor and wrap the result in a
    # `Tangent` matching the refactored (tensormap, indexmap) field layout.
    # The index structure itself is not differentiable, hence `NoTangent()`.
    # (The superseded `Union{T,Tangent{T}}` method is removed; the `Tangent`
    # case is handled by its own dedicated method.)
    return Tangent{TensorNetwork}(tensormap = projector.tensors(tensors(dx)), indexmap = NoTangent())
end

function (projector::ProjectTo{T})(dx::Tangent{T}) where {T<:AbstractTensorNetwork}
    # A structurally-zero tensor tangent projects to a structural zero.
    dx.tensormap isa NoTangent && return NoTangent()
    # BUG FIX: the guard above checks `dx.tensormap`, but the projection below
    # previously read `dx.tensors` — a field that no longer exists after the
    # incidence-matrix refactor (a `Tangent` yields a zero tangent for unknown
    # fields, silently dropping the gradient). Read `tensormap` consistently.
    return Tangent{TensorNetwork}(tensormap = projector.tensors(dx.tensormap), indexmap = NoTangent())
end

function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:AbstractTensorNetwork}
    # TODO match tensors by indices
    # NOTE: the result must not be bound to a local named `tensors` — under
    # Julia's scoping rules the assignment makes `tensors` local to the whole
    # function body, so the call `tensors(x)` on the same line would raise
    # `UndefVarError` instead of invoking the accessor.
    summed = map(+, tensors(x), Δ.tensormap)

    # Rebuild via the plain tensor-list constructor (the pre-refactor
    # field-reflection construction is gone along with the old fields).
    return T(summed)
end

function ChainRulesCore.frule((_, Δ), T::Type{<:AbstractTensorNetwork}, tensors)
    # Construction is linear in `tensors`, so the input tangent passes through
    # unchanged; it lands in `tensormap` under the refactored field layout and
    # the non-differentiable index structure gets `NoTangent()`.
    # (The stale pre-refactor return using the removed `tensors` field is dropped.)
    return T(tensors), Tangent{TensorNetwork}(tensormap = Δ, indexmap = NoTangent())
end

# Pullback for `T(tensors)`: the constructor contributes `NoTangent()` and the
# cotangent of the tensor collection is read from the refactored `tensormap`
# field. (The stale duplicate method reading the removed `tensors` field — which
# the definition below silently overwrote — is deleted.)
TensorNetwork_pullback(Δ::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensormap)
# Thunked cotangents are forced before being pulled back.
TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
function ChainRulesCore.rrule(T::Type{<:AbstractTensorNetwork}, tensors)
T(tensors), TensorNetwork_pullback
Expand Down
5 changes: 4 additions & 1 deletion ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ using ChainRulesTestUtils
using Random

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::T) where {T<:AbstractTensorNetwork}
    # One random tangent per tensor, each projected back to its tensor's own
    # shape; the index structure is not differentiable, hence `NoTangent()`.
    # (The stale pre-refactor `return Tangent{T}(tensors = ...)` line — which
    # would have returned first and made this code unreachable — is removed.)
    return Tangent{T}(
        tensormap = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)],
        indexmap = NoTangent(),
    )
end

end
15 changes: 2 additions & 13 deletions ext/TenetFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,8 @@ using Tenet: AbstractTensorNetwork
using FiniteDifferences

function FiniteDifferences.to_vec(x::T) where {T<:AbstractTensorNetwork}
    # Flatten only the tensor contents to a vector; `back` reverses that
    # flattening and the closure rebuilds the network through the plain
    # `T(tensors)` constructor. (The stale pre-refactor body that re-assembled
    # the struct by field reflection is removed along with the old fields.)
    x_vec, back = to_vec(tensors(x))
    TensorNetwork_from_vec(v) = T(back(v))

    return x_vec, TensorNetwork_from_vec
end
Expand Down
12 changes: 8 additions & 4 deletions ext/TenetMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,23 @@ function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::AbstractTensorNetw
hypermap = Tenet.hyperflatten(tn)
tn = transform(tn, Tenet.HyperindConverter)

tensormap = IdDict(tensor => i for (i, tensor) in enumerate(tensors(tn)))

# TODO how to mark multiedges? (i.e. parallel edges)
graph = SimpleGraph([Edge(tensors...) for (_, tensors) in tn.indices if length(tensors) > 1])
graph = SimpleGraph([
Edge(map(Base.Fix1(getindex, tensormap), tensors)...) for (_, tensors) in tn.indexmap if length(tensors) > 1
])

# TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations
copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), keys(hypermap)), tensors(tn))
ghostnodes = map(inds(tn, :open)) do ind
ghostnodes = map(inds(tn, :open)) do index
# create new ghost node
add_vertex!(graph)
node = nv(graph)

# connect ghost node
tensor = only(tn.indices[ind])
add_edge!(graph, node, tensor)
tensor = only(tn.indexmap[index])
add_edge!(graph, node, tensormap[tensor])

return node
end
Expand Down
Loading