Skip to content

Commit

Permalink
Merge pull request #169 from Jutho/jh/tensorstructure
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
Jutho authored Nov 1, 2024
2 parents cddfaa6 + c2a44bd commit a012324
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 38 deletions.
31 changes: 11 additions & 20 deletions ext/TensorKitChainRulesCoreExt/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ end
#
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(eltype(S))^(3 / 4))
tol::Real=default_pullback_gaugetol(S))

# Basic size checks and determination
m, n = size(U, 1), size(Vd, 2)
Expand Down Expand Up @@ -214,8 +213,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
Vp = view(Vd, 1:p, :)'
Sp = view(S, 1:p)

# tolerance and rank
tol = atol > 0 ? atol : rtol * S[1, 1]
# rank
r = findlast(>=(tol), S)

# compute antihermitian part of projection of ΔU and ΔV onto U and V
Expand Down Expand Up @@ -302,16 +300,12 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
end

function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
tol::Real=default_pullback_gaugetol(D))

# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())

# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)

if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV

Expand Down Expand Up @@ -345,16 +339,12 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix
end

function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
tol::Real=default_pullback_gaugetol(D))

# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())

# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)

if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
Expand All @@ -379,10 +369,8 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri
end

function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))
tol::Real=default_pullback_gaugetol(R))
Rd = view(R, diagind(R))
tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
p = findlast(>=(tol) abs, Rd)
m, n = size(R)

Expand Down Expand Up @@ -432,10 +420,8 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
end

function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))
tol::Real=default_pullback_gaugetol(L))
Ld = view(L, diagind(L))
tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
p = findlast(>=(tol) abs, Ld)
m, n = size(L)

Expand Down Expand Up @@ -483,3 +469,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
ldiv!(LowerTriangular(L11)', ΔA1)
return ΔA
end

function default_pullback_gaugetol(a)
n = norm(a, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end
3 changes: 2 additions & 1 deletion src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ function TO.tensoralloc(::Type{TT},
A = storagetype(TT)
dim = fusionblockstructure(structure).totaldim
data = TO.tensoralloc(A, dim, istemp, allocator)
return TT(data, structure)
# return TT(data, structure)
return TensorMap{T}(data, structure)
end

function TO.tensorfree!(t::TensorMap, allocator=TO.DefaultAllocator())
Expand Down
42 changes: 25 additions & 17 deletions test/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,14 @@ end
GL′ = force_planar(GL)
GR′ = force_planar(GR)

@tensor y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * O[2 -2; 3 5] * GR[4 5; -3]
@planar y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * O′[2 -2; 3 5] * GR′[4 5; -3]
@test force_planar(y) y′
for alloc in
(TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator())
@tensor allocator = alloc y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] *
O[2 -2; 3 5] * GR[4 5; -3]
@planar allocator = alloc y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] *
O′[2 -2; 3 5] * GR′[4 5; -3]
@test force_planar(y) y′
end

# ∂AC2
# -------
Expand Down Expand Up @@ -193,21 +198,24 @@ end
ρ′ = force_planar(ρ)
h′ = force_planar(h)

@tensor begin
C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) *
(u[8 5; 15 6] * w[6 7; 19])) *
(conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) *
((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) *
w[16 15; 18]) * conj(w[16 17; 21]))
end
@planar begin
C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) *
(u′[8 5; 15 6] * w′[6 7; 19])) *
(conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) *
((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) *
w′[16 15; 18]) * conj(w′[16 17; 21]))
for alloc in
(TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator())
@tensor allocator = alloc begin
C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) *
(u[8 5; 15 6] * w[6 7; 19])) *
(conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) *
((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) *
w[16 15; 18]) * conj(w[16 17; 21]))
end
@planar allocator = alloc begin
C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) *
(u′[8 5; 15 6] * w′[6 7; 19])) *
(conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) *
((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) *
w′[16 15; 18]) * conj(w′[16 17; 21]))
end
@test C C′
end
@test C C′
end

@testset "Issue 93" begin
Expand Down

0 comments on commit a012324

Please sign in to comment.