diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 984ca6868..8b8676674 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -437,6 +437,9 @@ end function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor}) tn = TensorNetwork(tn) old_tensor, new_tensor = pair + + old_tensor === new_tensor && return tn + issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match")) push!(tn, new_tensor) diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index c30b97b31..3a82c2acd 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -459,60 +459,75 @@ end @testset "replace tensors" begin - t_ij = Tensor(zeros(2, 2), (:i, :j)) - t_ik = Tensor(zeros(2, 2), (:i, :k)) - t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) - t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) - - old_tensor = t_lm + @testset "Basic replacement" begin + t_ij = Tensor(zeros(2, 2), (:i, :j)) + t_ik = Tensor(zeros(2, 2), (:i, :k)) + t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) + t_lm = Tensor(zeros(2, 2), (:l, :m)) + tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) + + old_tensor = t_lm + + @test_throws ArgumentError begin + new_tensor = Tensor(rand(2, 2), (:a, :b)) + replace!(tn, old_tensor => new_tensor) + end - @test_throws ArgumentError begin - new_tensor = Tensor(rand(2, 2), (:a, :b)) + new_tensor = Tensor(rand(2, 2), (:l, :m)) replace!(tn, old_tensor => new_tensor) + + @test new_tensor === only(filter(t -> issetequal(inds(t), [:l, :m]), tensors(tn))) + + # Check if connections are maintained + for ind in inds(new_tensor) + tensors_with_ind = tn.indexmap[ind] + @test new_tensor ∈ tensors_with_ind + @test !(old_tensor ∈ tensors_with_ind) + end end - new_tensor = Tensor(rand(2, 2), (:l, :m)) - replace!(tn, old_tensor => new_tensor) + @testset "TensorNetwork with tensors of equal indices" begin + A = Tensor(rand(2, 2), (:u, :w)) + B = Tensor(rand(2, 2), (:u, :w)) + tn = TensorNetwork([A, B]) - @test new_tensor === only(filter(t -> issetequal(inds(t), [:l, :m]), tensors(tn))) + new_tensor = Tensor(rand(2, 2), (:u, :w)) - # Check if connections are maintained - # for label in inds(new_tensor) - # index = tn.inds[label] - # @test new_tensor in index.links - # @test !(old_tensor in index.links) - # end + replace!(tn, B => new_tensor) + @test A ∈ tensors(tn) + @test new_tensor ∈ tensors(tn) - # New tensor network with two tensors with the same inds - # A = Tensor(rand(2, 2), (:u, :w)) - # B = Tensor(rand(2, 2), (:u, :w)) - # tn = TensorNetwork([A, B]) + tn = TensorNetwork([A, B]) + replace!(tn, A => new_tensor) - # new_tensor = Tensor(rand(2, 2), (:u, :w)) + @test issetequal(tensors(tn), [new_tensor, B]) + end - # replace!(tn, B => new_tensor) - # @test A === tensors(tn)[1] - # @test new_tensor === tensors(tn)[2] + @testset "Sequence of replacements" begin + A = Tensor(zeros(2, 2), (:i, :j)) + B = Tensor(zeros(2, 2), (:j, :k)) + C = Tensor(zeros(2, 2), (:k, :l)) + tn = TensorNetwork([A, B, C]) - # tn = TensorNetwork([A, B]) - # replace!(tn, A => new_tensor) + @test_throws ArgumentError replace!(tn, A => B, B => C, C => A) - # @test issetequal(tensors(tn), [new_tensor, B]) + new_tensor = Tensor(rand(2, 2), (:i, :j)) + new_tensor2 = Tensor(ones(2, 2), (:i, :j)) - # # Test chain of replacements - # A = Tensor(zeros(2, 2), (:i, :j)) - # B = Tensor(zeros(2, 2), (:j, :k)) - # C = Tensor(zeros(2, 2), (:k, :l)) - # tn = TensorNetwork([A, B, C]) + replace!(tn, A => new_tensor, new_tensor => new_tensor2) + @test issetequal(tensors(tn), [new_tensor2, B, C]) + end - # @test_throws ArgumentError replace!(tn, A => B, B => C, C => A) + @testset "Replace with itself" begin + A = Tensor(rand(2, 2), (:i, :j)) + B = Tensor(rand(2, 2), (:j, :k)) + C = Tensor(rand(2, 2), (:k, :l)) + tn = TensorNetwork([A, B, C]) - # new_tensor = Tensor(rand(2, 2), (:i, :j)) - # new_tensor2 = Tensor(ones(2, 2), (:i, :j)) + replace!(tn, A => A) - # replace!(tn, A => new_tensor, new_tensor => new_tensor2) - # @test issetequal(tensors(tn), [new_tensor2, B, C]) + @test issetequal(tensors(tn), [A, B, C]) + end end @testset "replace tensors by tensor network" begin