From 2abebc6a74704ff2b3d701a2f0f0b7c8b8211278 Mon Sep 17 00:00:00 2001 From: Jutho Date: Fri, 1 Nov 2024 11:20:29 +0100 Subject: [PATCH 1/2] better gauge warning tolerance for factorisation rrules --- .../factorizations.jl | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 4dcf481b..d4dc66f7 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -184,8 +184,7 @@ end # function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector, Vd::AbstractMatrix, ΔU, ΔS, ΔVd; - atol::Real=0, - rtol::Real=atol > 0 ? 0 : eps(eltype(S))^(3 / 4)) + tol::Real=default_pullback_gaugetol(S)) # Basic size checks and determination m, n = size(U, 1), size(Vd, 2) @@ -214,8 +213,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector Vp = view(Vd, 1:p, :)' Sp = view(S, 1:p) - # tolerance and rank - tol = atol > 0 ? atol : rtol * S[1, 1] + # rank r = findlast(>=(tol), S) # compute antihermitian part of projection of ΔU and ΔV onto U and V @@ -302,16 +300,12 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector end function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; - atol::Real=0, - rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4)) + tol::Real=default_pullback_gaugetol(D)) # Basic size checks and determination n = LinearAlgebra.checksquare(V) n == length(D) || throw(DimensionMismatch()) - # tolerance and rank - tol = atol > 0 ? atol : rtol * maximum(abs, D) - if !(ΔV isa AbstractZero) VdΔV = V' * ΔV @@ -345,16 +339,12 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix end function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; - atol::Real=0, - rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4)) + tol::Real=default_pullback_gaugetol(D)) # Basic size checks and determination n = LinearAlgebra.checksquare(V) n == length(D) || throw(DimensionMismatch()) - # tolerance and rank - tol = atol > 0 ? atol : rtol * maximum(abs, D) - if !(ΔV isa AbstractZero) VdΔV = V' * ΔV aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2) @@ -379,10 +369,8 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri end function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR; - atol::Real=0, - rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4)) + tol::Real=default_pullback_gaugetol(R)) Rd = view(R, diagind(R)) - tol = atol > 0 ? atol : rtol * maximum(abs, Rd) p = findlast(>=(tol) ∘ abs, Rd) m, n = size(R) @@ -432,10 +420,8 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, end function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ; - atol::Real=0, - rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4)) + tol::Real=default_pullback_gaugetol(L)) Ld = view(L, diagind(L)) - tol = atol > 0 ? atol : rtol * maximum(abs, Ld) p = findlast(>=(tol) ∘ abs, Ld) m, n = size(L) @@ -483,3 +469,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ldiv!(LowerTriangular(L11)', ΔA1) return ΔA end + +function default_pullback_gaugetol(a) + n = norm(a, Inf) + return eps(eltype(n))^(3 / 4) * max(n, one(n)) +end From c2a44bdbd69e0ca105d9b1c66ffe0a8dbe3eb97b Mon Sep 17 00:00:00 2001 From: Jutho Date: Fri, 1 Nov 2024 11:20:56 +0100 Subject: [PATCH 2/2] fix tensoralloc to support other allocators --- src/tensors/tensoroperations.jl | 3 ++- test/planar.jl | 42 ++++++++++++++++++++------------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 07a96d1c..02b7c902 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -13,7 +13,8 @@ function TO.tensoralloc(::Type{TT}, A = storagetype(TT) dim = fusionblockstructure(structure).totaldim data = TO.tensoralloc(A, dim, istemp, allocator) - return TT(data, structure) + # return TT(data, structure) + return TensorMap{T}(data, structure) end function TO.tensorfree!(t::TensorMap, allocator=TO.DefaultAllocator()) diff --git a/test/planar.jl b/test/planar.jl index 42281f92..0487404f 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -132,9 +132,14 @@ end GL′ = force_planar(GL) GR′ = force_planar(GR) - @tensor y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * O[2 -2; 3 5] * GR[4 5; -3] - @planar y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * O′[2 -2; 3 5] * GR′[4 5; -3] - @test force_planar(y) ≈ y′ + for alloc in + (TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator()) + @tensor allocator = alloc y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * + O[2 -2; 3 5] * GR[4 5; -3] + @planar allocator = alloc y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * + O′[2 -2; 3 5] * GR′[4 5; -3] + @test force_planar(y) ≈ y′ + end # ∂AC2 # ------- @@ -193,21 +198,24 @@ end ρ′ = force_planar(ρ) h′ = force_planar(h) - @tensor begin - C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) * - (u[8 5; 15 6] * w[6 7; 19])) * - (conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) * - ((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) * - w[16 15; 18]) * conj(w[16 17; 21])) - end - @planar begin - C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) * - (u′[8 5; 15 6] * w′[6 7; 19])) * - (conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) * - ((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) * - w′[16 15; 18]) * conj(w′[16 17; 21])) + for alloc in + (TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator()) + @tensor allocator = alloc begin + C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) * + (u[8 5; 15 6] * w[6 7; 19])) * + (conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) * + ((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) * + w[16 15; 18]) * conj(w[16 17; 21])) + end + @planar allocator = alloc begin + C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) * + (u′[8 5; 15 6] * w′[6 7; 19])) * + (conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) * + ((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) * + w′[16 15; 18]) * conj(w′[16 17; 21])) + end + @test C ≈ C′ end - @test C ≈ C′ end @testset "Issue 93" begin