From e35e4bf2e2b05934452aac62b01949f055957263 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 31 Oct 2024 10:52:40 +0100 Subject: [PATCH 1/4] Replace adjoint function --- src/Quantum.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/Quantum.jl b/src/Quantum.jl index f3f459d65..3eed0db3d 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -335,7 +335,21 @@ end Returns the adjoint of a [`Quantum`](@ref) Tensor Network; i.e. the conjugate Tensor Network with the inputs and outputs swapped. """ -Base.adjoint(tn::AbstractQuantum) = adjoint!(deepcopy(tn)) +function Base.adjoint(tn::AbstractQuantum) + tn = conj(tn) + + # update site information + oldsites = copy(Quantum(tn).sites) + empty!(Quantum(tn).sites) + for (site, index) in oldsites + addsite!(tn, site', index) + end + + # rename inner indices + replace!(tn, map(i -> i => Symbol(i, "'"), inds(tn; set=:virtual))) + + return tn +end function LinearAlgebra.adjoint!(tn::AbstractQuantum) conj!(tn) From a9221e98b95214e14fceaac277502fe575159f73 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 31 Oct 2024 10:52:48 +0100 Subject: [PATCH 2/4] Replace conj function --- src/TensorNetwork.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 8b8676674..557a150de 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -87,7 +87,17 @@ end Base.eltype(tn::AbstractTensorNetwork) = promote_type(eltype.(tensors(tn))...) -Base.conj(tn::AbstractTensorNetwork) = conj!(deepcopy(tn)) +""" + conj(tn::AbstractTensorNetwork) + +Return a copy of the [`AbstractTensorNetwork`](@ref) with all tensors conjugated. +""" +function Base.conj(tn::AbstractTensorNetwork) + tn = copy(tn) + replace!(tn, Pair.(tensors(tn), conj.(tensors(tn)))) + return tn +end + function Base.conj!(tn::AbstractTensorNetwork) foreach(conj!, tensors(tn)) return tn From f55d9590be421fcab0a246c8d363c6fbe4327df9 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 31 Oct 2024 12:06:23 +0100 Subject: [PATCH 3/4] Add and extend tests --- test/Quantum_test.jl | 19 +++++++++++++++++++ test/TensorNetwork_test.jl | 8 ++++++++ 2 files changed, 27 insertions(+) diff --git a/test/Quantum_test.jl b/test/Quantum_test.jl index cb15401b8..5d463d2bd 100644 --- a/test/Quantum_test.jl +++ b/test/Quantum_test.jl @@ -55,6 +55,25 @@ @test_throws ErrorException Quantum(tn, Dict(site"1" => :j)) @test_throws ErrorException Quantum(tn, Dict(site"1" => :i)) + @testset "Base.adjoint" begin + _tensors = Tensor[ + Tensor(rand(ComplexF64, 2, 4, 2), [:i, :link, :j]), Tensor(rand(ComplexF64, 2, 4, 2), [:k, :link, :l]) + ] + tn = TensorNetwork(_tensors) + qtn = Quantum(tn, Dict(site"1" => :i, site"2" => :k, site"1'" => :j, site"2'" => :l)) + + adjoint_qtn = adjoint(qtn) + + @test nsites(adjoint_qtn; set=:inputs) == nsites(adjoint_qtn; set=:outputs) == 2 + @test issetequal(sites(adjoint_qtn), [site"1", site"2", site"1'", site"2'"]) + @test socket(adjoint_qtn) == Operator() + @test inds(adjoint_qtn; at=site"1'") == :i # now the indices are flipped + @test inds(adjoint_qtn; at=site"1") == :j + @test inds(adjoint_qtn; at=site"2'") == :k + @test inds(adjoint_qtn; at=site"2") == :l + @test isapprox(tensors(adjoint_qtn), replace.(conj.(_tensors), :link => Symbol(:link, "'"))) + end + @testset "reindex!" begin @testset "manual indices" begin # mps-like tensor network diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 3a82c2acd..73fbea7ac 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -643,6 +643,14 @@ @test issetequal(inds(projvirttn), [:i, :k]) end + @testset "Base.conj" begin + tensor1 = Tensor(rand(ComplexF64, 3, 4), (:i, :j)) + tensor2 = Tensor(rand(ComplexF64, 4, 5), (:j, :k)) + complextn = TensorNetwork([tensor1, tensor2]) + + @test -imag.(tensors(complextn)) == imag.(tensors(conj(complextn))) + end + @testset "Base.conj!" begin @testset "for complex" begin tensor1 = Tensor(rand(ComplexF64, 3, 4), (:i, :j)) From 3da1a78eb6e7c8fb8651eb1965bb53a8af08dc5e Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Thu, 31 Oct 2024 15:09:00 +0100 Subject: [PATCH 4/4] Format code --- src/Quantum.jl | 22 ++++------------------ src/TensorNetwork.jl | 2 +- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/Quantum.jl b/src/Quantum.jl index 3eed0db3d..e9f98b00e 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -335,26 +335,12 @@ end Returns the adjoint of a [`Quantum`](@ref) Tensor Network; i.e. the conjugate Tensor Network with the inputs and outputs swapped. """ -function Base.adjoint(tn::AbstractQuantum) - tn = conj(tn) +Base.adjoint(tn::AbstractQuantum) = adjoint_sites!(conj(tn)) - # update site information - oldsites = copy(Quantum(tn).sites) - empty!(Quantum(tn).sites) - for (site, index) in oldsites - addsite!(tn, site', index) - end - - # rename inner indices - replace!(tn, map(i -> i => Symbol(i, "'"), inds(tn; set=:virtual))) - - return tn -end - -function LinearAlgebra.adjoint!(tn::AbstractQuantum) - conj!(tn) +LinearAlgebra.adjoint!(tn::AbstractQuantum) = adjoint_sites!(conj!(tn)) - # update site information +# update site information and rename inner indices +function adjoint_sites!(tn::AbstractQuantum) oldsites = copy(Quantum(tn).sites) empty!(Quantum(tn).sites) for (site, index) in oldsites diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 557a150de..f5f4b8fad 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -94,7 +94,7 @@ Return a copy of the [`AbstractTensorNetwork`](@ref) with all tensors conjugated """ function Base.conj(tn::AbstractTensorNetwork) tn = copy(tn) - replace!(tn, Pair.(tensors(tn), conj.(tensors(tn)))) + replace!(tn, tensors(tn) .=> conj.(tensors(tn))) return tn end