Skip to content

Commit

Permalink
Stop importing Graphs.contract and export our own symbol
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Dec 27, 2024
1 parent e783620 commit 1309df4
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 36 deletions.
1 change: 0 additions & 1 deletion ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using Tenet
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Graphs

const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent

Expand Down
2 changes: 1 addition & 1 deletion ext/TenetGraphMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module TenetGraphMakieExt

using Tenet
using GraphMakie
using Graphs
using Graphs: Graphs
using Makie
using Combinatorics: combinations
const NetworkLayout = GraphMakie.NetworkLayout
Expand Down
12 changes: 6 additions & 6 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using KeywordDispatch
using LinearAlgebra
using Graphs
using Graphs: Graphs

# Traits
"""
Expand Down Expand Up @@ -98,7 +98,7 @@ struct Ansatz <: AbstractAnsatz
lattice::Lattice

function Ansatz(tn, lattice)
if !issetequal(lanes(tn), vertices(lattice))
if !issetequal(lanes(tn), Graphs.vertices(lattice))
throw(ArgumentError("Sites of the tensor network and the lattice must be equal"))
end
return new(tn, lattice)
Expand Down Expand Up @@ -137,14 +137,14 @@ end
Return the neighboring sites of a given [`Site`](@ref) in the [`Lattice`](@ref) of the [`AbstractAnsatz`](@ref) Tensor Network.
"""
Graphs.neighbors(tn::AbstractAnsatz, site::Site) = neighbors(lattice(tn), site)
Graphs.neighbors(tn::AbstractAnsatz, site::Site) = Graphs.neighbors(lattice(tn), site)

"""
has_edge(tn::AbstractAnsatz, a::Site, b::Site)
Check whether there is an edge between two [`Site`](@ref)s in the [`Lattice`](@ref) of the [`AbstractAnsatz`](@ref) Tensor Network.
"""
Graphs.has_edge(tn::AbstractAnsatz, a::Site, b::Site) = has_edge(lattice(tn), a, b)
Graphs.has_edge(tn::AbstractAnsatz, a::Site, b::Site) = Graphs.has_edge(lattice(tn), a, b)

"""
inds(tn::AbstractAnsatz; bond)
Expand All @@ -156,7 +156,7 @@ Return the index of the virtual bond between two [`Site`](@ref)s in a [`Abstract
@assert site1 sites(tn) "Site $site1 not found"
@assert site2 sites(tn) "Site $site2 not found"
@assert site1 != site2 "Sites must be different"
@assert has_edge(tn, site1, site2) "Sites must be neighbors"
@assert Graphs.has_edge(tn, site1, site2) "Sites must be neighbors"

tensor1 = tensors(tn; at=site1)
tensor2 = tensors(tn; at=site2)
Expand Down Expand Up @@ -477,7 +477,7 @@ function simple_update_2site!(::MixedCanonical, ψ::AbstractAnsatz, gate; kwargs
end

function simple_update_2site!(::NonCanonical, ψ::AbstractAnsatz, gate; kwargs...)
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"
@assert Graphs.has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
gate = copy(gate)
Expand Down
32 changes: 16 additions & 16 deletions src/Lattice.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Graphs
using Graphs: Graphs
using BijectiveDicts: BijectiveIdDict

struct LatticeEdge <: AbstractEdge{Site}
struct LatticeEdge <: Graphs.AbstractEdge{Site}
src::Site
dst::Site
end
Expand All @@ -23,7 +23,7 @@ A lattice is a graph where the vertices are [`Site`](@ref)s and the edges are vi
It is used for representing the topology of a [`Ansatz`](@ref) Tensor Network.
It fulfills the [`AbstractGraph`](https://juliagraphs.org/Graphs.jl/stable/core_functions/interface/) interface.
"""
struct Lattice <: AbstractGraph{Site}
struct Lattice <: Graphs.AbstractGraph{Site}
mapping::BijectiveIdDict{Site,Int}
graph::Graphs.SimpleGraph{Int}
end
Expand All @@ -43,7 +43,7 @@ Graphs.is_directed(::Type{Lattice}) = false
Return the vertices of the lattice; i.e. the list of [`Site`](@ref)s.
"""
function Graphs.vertices(lattice::Lattice)
return map(vertices(lattice.graph)) do vertex
return map(Graphs.vertices(lattice.graph)) do vertex
lattice.mapping'[vertex]
end
end
Expand All @@ -53,21 +53,21 @@ end
Return the edges of the lattice; i.e. pairs of [`Site`](@ref)s.
"""
Graphs.edges(lattice::Lattice) = LatticeEdgeIterator(edges(lattice.graph), lattice)
Graphs.edges(lattice::Lattice) = LatticeEdgeIterator(Graphs.edges(lattice.graph), lattice)

"""
Graphs.nv(::Lattice)
Return the number of vertices/[`Site`](@ref)s in the lattice.
"""
Graphs.nv(lattice::Lattice) = nv(lattice.graph)
Graphs.nv(lattice::Lattice) = Graphs.nv(lattice.graph)

"""
Graphs.ne(::Lattice)
Return the number of edges in the lattice.
"""
Graphs.ne(lattice::Lattice) = ne(lattice.graph)
Graphs.ne(lattice::Lattice) = Graphs.ne(lattice.graph)

"""
Graphs.has_vertex(lattice::Lattice, site::Site)
Expand All @@ -82,11 +82,11 @@ Graphs.has_vertex(lattice::Lattice, site::Site) = haskey(lattice.mapping, site)
Return `true` if the lattice has the given edge.
"""
Graphs.has_edge(lattice::Lattice, edge::LatticeEdge) = has_edge(lattice, edge.src, edge.dst)
Graphs.has_edge(lattice::Lattice, edge::LatticeEdge) = Graphs.has_edge(lattice, edge.src, edge.dst)
function Graphs.has_edge(lattice::Lattice, a::Site, b::Site)
return has_vertex(lattice, a) &&
has_vertex(lattice, b) &&
has_edge(lattice.graph, lattice.mapping[a], lattice.mapping[b])
return Graphs.has_vertex(lattice, a) &&
Graphs.has_vertex(lattice, b) &&
Graphs.has_edge(lattice.graph, lattice.mapping[a], lattice.mapping[b])
end

"""
Expand All @@ -95,9 +95,9 @@ end
Return the neighbors [`Site`](@ref)s of the given [`Site`](@ref).
"""
function Graphs.neighbors(lattice::Lattice, site::Site)
has_vertex(lattice, site) || throw(ArgumentError("site not in lattice"))
Graphs.has_vertex(lattice, site) || throw(ArgumentError("site not in lattice"))
vertex = lattice.mapping[site]
return map(neighbors(lattice.graph, vertex)) do neighbor
return map(Graphs.neighbors(lattice.graph, vertex)) do neighbor
lattice.mapping'[neighbor]
end
end
Expand All @@ -107,15 +107,15 @@ struct LatticeEdgeIterator <: Graphs.AbstractEdgeIter
lattice::Lattice
end

Graphs.ne(iterator::LatticeEdgeIterator) = ne(iterator.lattice)
Graphs.ne(iterator::LatticeEdgeIterator) = Graphs.ne(iterator.lattice)
Base.eltype(::Type{LatticeEdgeIterator}) = LatticeEdge
Base.length(iterator::LatticeEdgeIterator) = length(iterator.simpleit)
Base.in(e::LatticeEdge, it::LatticeEdgeIterator) = has_edge(it.lattice, src(e), src(dst))
Base.in(e::LatticeEdge, it::LatticeEdgeIterator) = Graphs.has_edge(it.lattice, Graphs.src(e), Graphs.src(dst))
Base.show(io::IO, iterator::LatticeEdgeIterator) = write(io, "LatticeEdgeIterator $(ne(iterator))")

function Base.iterate(iterator::LatticeEdgeIterator, state=nothing)
itres = isnothing(state) ? iterate(iterator.simpleit) : iterate(iterator.simpleit, state)
isnothing(itres) && return nothing
edge, state = itres
return LatticeEdge(iterator.lattice.mapping'[src(edge)], iterator.lattice.mapping'[dst(edge)]), state
return LatticeEdge(iterator.lattice.mapping'[Graphs.src(edge)], iterator.lattice.mapping'[Graphs.dst(edge)]), state
end
6 changes: 3 additions & 3 deletions src/MPS.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Random
using LinearAlgebra
using Graphs
using Graphs: Graphs
using BijectiveDicts: BijectiveIdDict

abstract type AbstractMPO <: AbstractAnsatz end
Expand Down Expand Up @@ -90,7 +90,7 @@ function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check=true)

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
graph = Graphs.path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
Expand Down Expand Up @@ -221,7 +221,7 @@ function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO))
sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n))
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
graph = Graphs.path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
Expand Down
6 changes: 3 additions & 3 deletions src/Product.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearAlgebra
using Graphs
using Graphs: Graphs

"""
Product <: AbstractAnsatz
Expand Down Expand Up @@ -33,7 +33,7 @@ function Product(arrays::AbstractArray{<:AbstractVector})

sitemap = Dict(Site(i) => symbols[i] for i in eachindex(arrays))
qtn = Quantum(TensorNetwork(_tensors), sitemap)
graph = Graph(n)
graph = Graphs.Graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
Expand All @@ -55,7 +55,7 @@ function Product(arrays::AbstractArray{<:AbstractMatrix})
Dict(Site(i) => symbols[i][2] for i in eachindex(arrays)),
)
qtn = Quantum(TensorNetwork(_tensors), sitemap)
graph = Graph(n)
graph = Graphs.Graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
Expand Down
1 change: 0 additions & 1 deletion src/Tenet.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Tenet

import EinExprs: inds
import Graphs: contract

include("Helpers.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra
using ScopedValues
using Serialization
using KeywordDispatch
using Graphs
using Graphs: Graphs

mutable struct CachedField{T}
isvalid::Bool
Expand Down
6 changes: 3 additions & 3 deletions test/Ansatz_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using LinearAlgebra
@test zero(ansatz) == Ansatz(zero(qtn), lattice)
@test Tenet.lattice(ansatz) == lattice
@test isempty(neighbors(ansatz, site"1"))
@test !Tenet.has_edge(ansatz, site"1", site"2")
@test !has_edge(ansatz, site"1", site"2")

# some AbstractQuantum methods
@test inds(ansatz; at=site"1") == :i
Expand Down Expand Up @@ -63,8 +63,8 @@ using LinearAlgebra
@test issetequal(neighbors(ansatz, site"1"), [site"2"])
@test issetequal(neighbors(ansatz, site"2"), [site"1"])

@test Tenet.has_edge(ansatz, site"1", site"2")
@test Tenet.has_edge(ansatz, site"2", site"1")
@test has_edge(ansatz, site"1", site"2")
@test has_edge(ansatz, site"2", site"1")

@test inds(ansatz; bond=(site"1", site"2")) == :i

Expand Down
2 changes: 1 addition & 1 deletion test/Product_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "Product ansatz" begin
@testset "Product" begin
using LinearAlgebra

# TODO test `Product` with `Scalar` socket
Expand Down
1 change: 1 addition & 0 deletions test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testset "TensorNetwork" begin
using Serialization
using Graphs: neighbors

@testset "Constructors" begin
@testset "empty" begin
Expand Down

0 comments on commit 1309df4

Please sign in to comment.