From 8e0c4339e1733f25e435654d5a12894ddf14a0a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 00:47:55 -0400 Subject: [PATCH 01/19] Refactor `Product` on top of new `Ansatz` type --- src/Ansatz/Product.jl | 84 +++++++++++++++++++++++++++++++++++++++++++ src/Tenet.jl | 2 ++ 2 files changed, 86 insertions(+) create mode 100644 src/Ansatz/Product.jl diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl new file mode 100644 index 000000000..fd6c07d9a --- /dev/null +++ b/src/Ansatz/Product.jl @@ -0,0 +1,84 @@ +using LinearAlgebra +using Graphs +using MetaGraphsNext + +struct Product <: AbstractAnsatz + tn::Ansatz +end + +Ansatz(tn::Product) = tn.tn + +Base.copy(x::Product) = Product(copy(Ansatz(x))) + +Base.similar(x::Product) = Product(similar(Ansatz(x))) +Base.zero(x::Product) = Product(zero(Ansatz(x))) + +function Product(arrays::Vector{<:AbstractVector}) + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:n] + _tensors = map(enumerate(arrays)) do (i, array) + Tensor(array, [symbols[i]]) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) +end + +function Product(arrays::Vector{<:AbstractMatrix}) + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] + _tensors = map(enumerate(arrays)) do (i, array) + Tensor(array, [symbols[i + n], symbols[i]], []) + end + + sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) +end + +function Base.zeros(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) + return Product(fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) +end + +function Base.ones(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) + return Product(fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n)) +end + +LinearAlgebra.norm(tn::Product, p::Real=2) = LinearAlgebra.norm(socket(tn), tn, p) +function LinearAlgebra.norm(::Union{State,Operator}, tn::Product, p::Real) + return mapreduce(*, tensors(tn)) do tensor + norm(tensor, p) + end^(1//p) +end + +LinearAlgebra.opnorm(tn::Product, p::Real=2) = LinearAlgebra.opnorm(socket(tn), tn, p) +function LinearAlgebra.opnorm(::Operator, tn::Product, p::Real) + return mapreduce(*, tensors(tn)) do tensor + opnorm(parent(tensor), p) + end^(1//p) +end + +LinearAlgebra.normalize!(tn::Product, p::Real=2) = LinearAlgebra.normalize!(socket(tn), tn, p) +function LinearAlgebra.normalize!(::Union{State,Operator}, tn::Product, p::Real) + for tensor in tensors(tn) + normalize!(tensor, p) + end + return tn +end + +overlap(a::Product, b::Product) = overlap(socket(a), a, socket(b), b) + +function overlap(::State, a::Product, ::State, b::Product) + @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" + + mapreduce(*, zip(tensors(a), tensors(b))) do (ta, tb) + dot(parent(ta), conj(parent(tb))) + end +end diff --git a/src/Tenet.jl b/src/Tenet.jl index 98aaeac14..8016f549a 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -34,6 +34,8 @@ export form export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! export evolve!, expect, overlap +include("Ansatz/Product.jl") + # reexports from EinExprs export einexpr, inds From 1f9ba62675a96c9d0c56c2435045b45e9a1a8a97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 01:02:44 -0400 Subject: [PATCH 02/19] Format code --- src/Ansatz/Product.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index fd6c07d9a..8c86f570c 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -9,7 +9,6 @@ end Ansatz(tn::Product) = tn.tn Base.copy(x::Product) = Product(copy(Ansatz(x))) - Base.similar(x::Product) = Product(similar(Ansatz(x))) Base.zero(x::Product) = Product(zero(Ansatz(x))) From 4eceaf695010d226e35d91d56f38d5c695b0b143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:42:45 -0400 Subject: [PATCH 03/19] Fix typo --- src/Ansatz/Product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 8c86f570c..0c5bde2be 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -32,7 +32,7 @@ function Product(arrays::Vector{<:AbstractMatrix}) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i + n], symbols[i]], []) + Tensor(array, [symbols[i + n], symbols[i]]) end sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) From 4ff32af1ec9207b7c449e01bf6edc94c07add5b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Nov 2024 22:36:20 +0100 Subject: [PATCH 04/19] Refactor `adapt_structure` method to support additional types --- ext/TenetAdaptExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index b3d09b0ed..1331b86a4 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -8,5 +8,6 @@ Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tens Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x)) +Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x))) end From 56f23012caa58ddc55abddd9afb7609445eabd68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Nov 2024 22:42:24 +0100 Subject: [PATCH 05/19] Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on top of recent changes --- ext/TenetReactantExt.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 22a5ec5b5..133ac4a5a 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -34,8 +34,12 @@ end function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Ansatz(tracetn, copy(Tenet.lattice(prev))) -end +# TODO try rely on generic fallback for ansatzes +function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return Tenet.Product(tracequantum) +end function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) return :($Tensor($data, $(inds(tocopy)))) @@ -58,6 +62,12 @@ function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stor return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy))))) end +# TODO try rely on generic fallback for ansatzes +function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($(Tenet.Product)($tn)) +end + # TODO try rely on generic fallback for ansatzes # function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) # tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) From 568f754f7539a83feeff8e5fb8d0c66545cd2107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Nov 2024 22:47:46 +0100 Subject: [PATCH 06/19] Refactor `ChainRules` methods on top of new types --- ext/TenetChainRulesCoreExt/frules.jl | 3 +++ ext/TenetChainRulesCoreExt/rrules.jl | 5 +++++ test/integration/ChainRules_test.jl | 6 ++++++ 3 files changed, 14 insertions(+) diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index d1f0d6355..d6bb3d43e 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -16,6 +16,9 @@ function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice) return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent()) end +# `AbstractAnsatz`-subtype constructors +ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ) + # `Base.conj` methods ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ) diff --git a/ext/TenetChainRulesCoreExt/rrules.jl b/ext/TenetChainRulesCoreExt/rrules.jl index 9c25e641d..a932dd062 100644 --- a/ext/TenetChainRulesCoreExt/rrules.jl +++ b/ext/TenetChainRulesCoreExt/rrules.jl @@ -20,6 +20,11 @@ Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback +# `AbstractAnsatz`-subtype constructors +Product_pullback(ȳ) = (NoTangent(), ȳ.tn) +Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback + # `Base.conj` methods conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ)) conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ)) diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index 44e94f9a3..ae5472abb 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -199,4 +199,10 @@ test_frule(Ansatz, tn, lattice) test_rrule(Ansatz, tn, lattice) end + + @testset "Product" begin + tn = Product([ones(2), ones(2), ones(2)]) + test_frule(Product, Ansatz(tn)) + test_rrule(Product, Ansatz(tn)) + end end From 10afdd38405cf5224d833f353087508f32521b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 12:27:36 +0100 Subject: [PATCH 07/19] Aesthetic name fix --- ext/TenetReactantExt.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 133ac4a5a..53fab1187 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -37,9 +37,10 @@ function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.Tr # TODO try rely on generic fallback for ansatzes function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return Tenet.Product(tracequantum) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return Tenet.Product(tracetn) end + function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) return :($Tensor($data, $(inds(tocopy)))) From 0c634c81264046a688959b2a7c0b46c289b8b50b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 12:28:22 +0100 Subject: [PATCH 08/19] Stop using `IdDict` on Reactant extension --- ext/TenetReactantExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 53fab1187..a3b13b8f1 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -36,7 +36,7 @@ function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.Tr return Ansatz(tracetn, copy(Tenet.lattice(prev))) # TODO try rely on generic fallback for ansatzes -function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Tenet.Product(tracetn) end From fad4e2b671e10d13abca126e6700118ce8aac573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 12:29:21 +0100 Subject: [PATCH 09/19] Refactor lattice generation in constructors of `Dense`, `Product`, `MPS`, `MPO`, `PEPS` --- src/Ansatz/Product.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 0c5bde2be..6412006d6 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -22,7 +22,8 @@ function Product(arrays::Vector{<:AbstractVector}) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) qtn = Quantum(TensorNetwork(_tensors), sitemap) - lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end @@ -37,7 +38,8 @@ function Product(arrays::Vector{<:AbstractMatrix}) sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) qtn = Quantum(TensorNetwork(_tensors), sitemap) - lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end From 71a9fc8b45f1f88ebeb7cb8b214590daecd68aa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 12:35:26 +0100 Subject: [PATCH 10/19] Reenable `Product` tests --- test/Product_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Product_test.jl b/test/Product_test.jl index 1c013c1f1..14619605d 100644 --- a/test/Product_test.jl +++ b/test/Product_test.jl @@ -1,4 +1,4 @@ -@testset_skip "Product ansatz" begin +@testset "Product ansatz" begin using LinearAlgebra # TODO test `Product` with `Scalar` socket From a22a2baa53fc018f7463dbfed14cee8633439916 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:47:41 +0100 Subject: [PATCH 11/19] fix constructors --- src/Ansatz/Product.jl | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 6412006d6..e9c9af351 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -1,6 +1,5 @@ using LinearAlgebra using Graphs -using MetaGraphsNext struct Product <: AbstractAnsatz tn::Ansatz @@ -12,34 +11,43 @@ Base.copy(x::Product) = Product(copy(Ansatz(x))) Base.similar(x::Product) = Product(similar(Ansatz(x))) Base.zero(x::Product) = Product(zero(Ansatz(x))) -function Product(arrays::Vector{<:AbstractVector}) +function Product(arrays::AbstractArray{<:AbstractVector}) n = length(arrays) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:n] - _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i]]) + symbols = map(arrays) do _ + nextindex!(gen) + end + _tensors = map(eachindex(arrays)) do i + Tensor(arrays[i], [symbols[i]]) end - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + sitemap = Dict(Site(i) => symbols[i] for i in eachindex(arrays)) qtn = Quantum(TensorNetwork(_tensors), sitemap) graph = Graph(n) - lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for i in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end -function Product(arrays::Vector{<:AbstractMatrix}) +function Product(arrays::AbstractArray{<:AbstractMatrix}) n = length(arrays) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] - _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i + n], symbols[i]]) + symbols = map(arrays) do _ + (nextindex!(gen), nextindex!(gen)) + end + _tensors = map(eachindex(arrays)) do i + Tensor(arrays[i], [symbols[i][1], symbols[i][2]]) end - sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) + sitemap = merge!( + Dict(Site(i; dual=true) => symbols[i][1] for i in eachindex(arrays)), + Dict(Site(i) => symbols[i][2] for i in eachindex(arrays)), + ) qtn = Quantum(TensorNetwork(_tensors), sitemap) graph = Graph(n) - lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for i in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end From b56264ad9a7062bb6c30ba82ee66d3fb7d5a0527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:47:55 +0100 Subject: [PATCH 12/19] remove `zeros`, `ones` (not well defined) --- src/Ansatz/Product.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index e9c9af351..e023d3c38 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -52,14 +52,6 @@ function Product(arrays::AbstractArray{<:AbstractMatrix}) return Product(ansatz) end -function Base.zeros(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product(fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) -end - -function Base.ones(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product(fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n)) -end - LinearAlgebra.norm(tn::Product, p::Real=2) = LinearAlgebra.norm(socket(tn), tn, p) function LinearAlgebra.norm(::Union{State,Operator}, tn::Product, p::Real) return mapreduce(*, tensors(tn)) do tensor From c2c65d0899609bc47ede0df604cf3a25e96edb29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:48:04 +0100 Subject: [PATCH 13/19] fix symbol import --- ext/TenetChainRulesCoreExt/frules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index d6bb3d43e..e9bc53325 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -1,4 +1,4 @@ -using Tenet: AbstractTensorNetwork, AbstractQuantum +using Tenet: AbstractTensorNetwork, AbstractQuantum, Product # `Tensor` constructor ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds) From 57153402b9f4349aa9a4f5a659bf0a26f8ac37e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:51:00 +0100 Subject: [PATCH 14/19] export `Product` --- ext/TenetChainRulesCoreExt/frules.jl | 2 +- src/Tenet.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index e9bc53325..d6bb3d43e 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -1,4 +1,4 @@ -using Tenet: AbstractTensorNetwork, AbstractQuantum, Product +using Tenet: AbstractTensorNetwork, AbstractQuantum # `Tensor` constructor ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds) diff --git a/src/Tenet.jl b/src/Tenet.jl index 8016f549a..0ca5d2dc0 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -35,6 +35,7 @@ export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed export evolve!, expect, overlap include("Ansatz/Product.jl") +export Product # reexports from EinExprs export einexpr, inds From cd66c892abd2db53e9c629d9174964d816116f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:56:11 +0100 Subject: [PATCH 15/19] fix constructors --- src/Ansatz/Product.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index e023d3c38..6c2ed9cd7 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -24,7 +24,7 @@ function Product(arrays::AbstractArray{<:AbstractVector}) sitemap = Dict(Site(i) => symbols[i] for i in eachindex(arrays)) qtn = Quantum(TensorNetwork(_tensors), sitemap) graph = Graph(n) - mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for i in enumerate(lanes(qtn))]) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) lattice = Lattice(mapping, graph) ansatz = Ansatz(qtn, lattice) return Product(ansatz) @@ -46,7 +46,7 @@ function Product(arrays::AbstractArray{<:AbstractMatrix}) ) qtn = Quantum(TensorNetwork(_tensors), sitemap) graph = Graph(n) - mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for i in enumerate(lanes(qtn))]) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) lattice = Lattice(mapping, graph) ansatz = Ansatz(qtn, lattice) return Product(ansatz) From 8d887f2d76c566676d6e8d8ce4bfc607ea05f240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 16:57:24 +0100 Subject: [PATCH 16/19] refactor tests --- test/Product_test.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/Product_test.jl b/test/Product_test.jl index 14619605d..b6fc318fc 100644 --- a/test/Product_test.jl +++ b/test/Product_test.jl @@ -14,9 +14,7 @@ end @test adjoint(qtn) isa Product @test socket(adjoint(qtn)) == State(; dual=true) - - # conversion to `Quantum` - @test Quantum(qtn) isa Quantum + @test Ansatz(qtn) isa Ansatz qtn = Product([rand(2, 2) for _ in 1:3]) @test socket(qtn) == Operator() @@ -30,4 +28,5 @@ end @test adjoint(qtn) isa Product @test socket(adjoint(qtn)) == Operator() + @test Ansatz(qtn) isa Ansatz end From bd7752fc204ce8d838f46f524d3f417e76aff7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 18:37:15 +0100 Subject: [PATCH 17/19] fix typo --- ext/TenetReactantExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index a3b13b8f1..9e29248ca 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -34,6 +34,7 @@ end function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Ansatz(tracetn, copy(Tenet.lattice(prev))) +end # TODO try rely on generic fallback for ansatzes function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) From a89e984fd94de14e4409c6bf318e4cd18d057801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 5 Nov 2024 18:39:05 +0100 Subject: [PATCH 18/19] move file --- src/{Ansatz => }/Product.jl | 0 src/Tenet.jl | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{Ansatz => }/Product.jl (100%) diff --git a/src/Ansatz/Product.jl b/src/Product.jl similarity index 100% rename from src/Ansatz/Product.jl rename to src/Product.jl diff --git a/src/Tenet.jl b/src/Tenet.jl index 0ca5d2dc0..295fffc33 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -34,7 +34,7 @@ export form export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! export evolve!, expect, overlap -include("Ansatz/Product.jl") +include("Product.jl") export Product # reexports from EinExprs From 8cc5d132e45a469acea1d9b9719a9206ecdd8a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 7 Nov 2024 09:24:16 +0100 Subject: [PATCH 19/19] Document `Product` constructor --- src/Product.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/Product.jl b/src/Product.jl index 6c2ed9cd7..baf074e76 100644 --- a/src/Product.jl +++ b/src/Product.jl @@ -1,6 +1,16 @@ using LinearAlgebra using Graphs +""" + Product <: AbstractAnsatz + +An [`Ansatz`](@ref) represented as a tensor product. + +# Constructors + +If you pass an `Abstract{<:AbstractVector}` to the constructor, it will create a [`State`](@ref). +If you pass an `Abstract{<:AbstractMatrix}` to the constructor, it will create an [`Operator`](@ref). +""" struct Product <: AbstractAnsatz tn::Ansatz end