Skip to content

Commit

Permalink
various changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 4, 2020
1 parent a2bdb1c commit 7f5d989
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 54 deletions.
94 changes: 56 additions & 38 deletions src/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Grassmann
# with thus W'*Z = 0, W'*U = 0

using TensorKit
using TensorKit: similarstoragetype, fusiontreetype, SectorDict
using TensorKit: similarstoragetype, SectorDict
using ..TensorKitManifolds: projecthermitian!, projectantihermitian!,
projectisometric!, projectcomplement!, PolarNewton
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!
Expand All @@ -28,22 +28,10 @@ mutable struct GrassmannTangent{T<:AbstractTensorMap,
G = sectortype(W)
M = similarstoragetype(W, TT)
Mr = similarstoragetype(W, real(TT))
if G === Trivial
TU = TensorMap{S,N₁,1,G,M,Nothing,Nothing}
TS = TensorMap{S,1,1,G,Mr,Nothing,Nothing}
TV = TensorMap{S,1,N₂,G,M,Nothing,Nothing}
return new{T,TU,TS,TV}(W, Z, nothing, nothing, nothing)
else
F = fusiontreetype(G, 1)
F1 = fusiontreetype(G, N₁)
F2 = fusiontreetype(G, N₂)
D = SectorDict{G,M}
Dr = SectorDict{G,Mr}
TU = TensorMap{S,N₁,1,G,D,F1,F}
TS = TensorMap{S,1,1,G,Dr,F,F}
TV = TensorMap{S,1,N₂,G,D,F,F2}
return new{T,TU,TS,TV}(W, Z, nothing, nothing, nothing)
end
TU = tensormaptype(S, N₁, 1, M)
TS = tensormaptype(S, 1, 1, Mr)
TV = tensormaptype(S, 1, N₂, M)
return new{T,TU,TS,TV}(W, Z, nothing, nothing, nothing)
end
end
function Base.copy::GrassmannTangent)
Expand Down Expand Up @@ -99,14 +87,20 @@ Base.zero(Δ::GrassmannTangent) = GrassmannTangent(Δ.W, zero(Δ.Z))
function TensorKit.rmul!::GrassmannTangent, α::Number)
rmul!.Z, α)
if Base.getfield(Δ, :S) !== nothing
rmul!.S, α)
if sign(α) != 1
rmul!.V, sign(α))
end
rmul!.S, abs(α))
end
return Δ
end
function TensorKit.lmul!::Number, Δ::GrassmannTangent)
lmul!(α, Δ.Z)
if Base.getfield(Δ, :S) !== nothing
lmul!(α, Δ.S)
if sign(α) != 1
lmul!(sign(α), Δ.U)
end
lmul!(abs(α), Δ.S)
end
return Δ
end
Expand Down Expand Up @@ -141,14 +135,36 @@ function project!(X::AbstractTensorMap, W::AbstractTensorMap; metric = :euclidea
Z = projectcomplement!(Z, W)
return GrassmannTangent(W, Z)
end
project(X, W; metric = :euclidean) = project!(copy(X), W; metric=metric)

"""
Grassmann.project(X::AbstractTensorMap, W::AbstractTensorMap; metric = :euclidean)
Project X onto the Grassmann tangent space at the base point `W`, which is assumed to be
isometric, i.e. `W'*W ≈ id(domain(W))`. The resulting tensor `Z` in the tangent space of
`W` is given by `Z = X - W * (W'*X)` and satisfies `W'*Z = 0`.
"""
project(X::AbstractTensorMap, W::AbstractTensorMap; metric = :euclidean) =
project!(copy(X), W; metric=metric)

function inner(W::AbstractTensorMap, Δ₁::GrassmannTangent, Δ₂::GrassmannTangent;
metric = :euclidean)
@assert metric == :euclidean
Δ₁ === Δ₂ ? norm(Δ₁)^2 : real(dot(Δ₁, Δ₂))
end

"""
retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = nothing)
Retract isometry `W == base(Δ)` within the Grassmann manifold using tangent vector `Δ.Z`.
If the singular value decomposition of `Z` is given by `U * S * V`, then the resulting
isometry is
`W′ = W * V' * cos(α*S) * V + U * sin(α * S) * V`
while the local tangent vector along the retraction curve is
`Z′ = - W * V' * sin(α*S) * S * V + U * cos(α * S) * S * V'`.
"""
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = nothing)
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
U, S, V = Δ.U, Δ.S, Δ.V
Expand All @@ -162,39 +178,40 @@ function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α; alg = nothing)
end

"""
invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = nothing)
Grassmann.invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = nothing)
Return the Grassmann tangent Z and unitary Y such that `retract(Wold, Z, 1) * Y ≈ Wnew`.
Return the Grassmann tangent `Z` and unitary `Y` such that `retract(Wold, Z, 1) * Y ≈ Wnew`.
This is done by solving the equation `Wold * V * cos(S) * V' + U * sin(S) * V' = Wnew * Y`
for the isometries V, U, and Y, and the diagonal matrix S, and returning Z, Y, where
`Z = U * S * V'`.
This is done by solving the equation `Wold * V' * cos(S) * V + U * sin(S) * V = Wnew * Y'`
for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning
`Z = U * S * V` and `Y`.
"""
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg = nothing)
space(Wold) == space(Wnew) || throw(SectorMismatch())
WodWn = Wold' * Wnew
res = Wnew - Wold * WodWn
V, cS, Xd = tsvd!(WodWn)
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
Wneworth = Wnew - Wold * WodWn
Vd, cS, VY = tsvd!(WodWn)
Scmplx = acos(cS)
# acos always returns a complex TensorMap. We cast back to real if possible.
S = eltype(WodWn) <: Real ? real(Scmplx) : Scmplx
u, s, vd = tsvd!(res * Xd' * sin(S))
Z = Grassmann.GrassmannTangent(Wold, u * vd * S * V')
Y = V * Xd
UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form
U = projectisometric!(UsS; alg = Polar())
Y = Vd*VY
V = Vd'
Z = Grassmann.GrassmannTangent(Wold, U * S * V)
return Z, Y
end

"""
matchgauge(W::AbstractTensorMap, V::AbstractTensorMap)
relativegauge(W::AbstractTensorMap, V::AbstractTensorMap)
Return the unitary Y such that V*Y and W are "in the same Grassmann gauge", in the sense
that they can be connected by a Grassmann retraction.
Return the unitary Y such that V*Y and W are "in the same Grassmann gauge" (technical term
from fibre bundles: in the same section), such that they can be related by a Grassmann
retraction.
"""
function matchgauge(W::AbstractTensorMap, V::AbstractTensorMap)
function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap)
space(W) == space(V) || throw(SectorMismatch())
WdV = W' * V
u, s, v = tsvd!(WdV)
return v' * u'
return projectisometric!(V'*W; alg = Polar())
end

function transport!::GrassmannTangent, W::AbstractTensorMap, Δ::GrassmannTangent, α, W′;
Expand Down Expand Up @@ -227,6 +244,7 @@ function _sincosSV(α::Real, S::AbstractTensorMap, V::AbstractTensorMap)
Threads.@threads for j = 1:size(bV,2)
@simd for i = 1:size(bV, 1)
sS, cS = sincos*bS[i,i])
# TODO: we are computing sin and cos above within the loop over j, while it is independent; moving it out the loop requires extra storage though.
bsSV[i,j] = sS*bV[i,j]
bcSV[i,j] = cS*bV[i,j]
end
Expand Down
2 changes: 1 addition & 1 deletion src/stiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import ..TensorKitManifolds: base, checkbase,
inner, retract, transport, transport!

# special type to store tangent vectors using A and Z = Q*R,
mutable struct StiefelTangent{T<:AbstractTensorMap, TA<:AbstractTensorMap}
struct StiefelTangent{T<:AbstractTensorMap, TA<:AbstractTensorMap}
W::T
A::TA
Z::T
Expand Down
15 changes: 2 additions & 13 deletions src/unitary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import TensorKit: similarstoragetype, SectorDict
using ..TensorKitManifolds: projectantihermitian!, projectisometric!, PolarNewton
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!

mutable struct UnitaryTangent{T<:AbstractTensorMap, TA<:AbstractTensorMap}
struct UnitaryTangent{T<:AbstractTensorMap, TA<:AbstractTensorMap}
W::T
A::TA
function UnitaryTangent(W::AbstractTensorMap{S,N₁,N₂},
Expand All @@ -24,17 +24,6 @@ base(Δ::UnitaryTangent) = Δ.W
checkbase(Δ₁::UnitaryTangent, Δ₂::UnitaryTangent) = Δ₁.W == Δ₂.W ? Δ₁.W :
throw(ArgumentError("tangent vectors with different base points"))

function Base.getproperty::UnitaryTangent, sym::Symbol)
if sym (:W, :A)
return Base.getfield(Δ, sym)
else
error("type UnitaryTangent has no field $sym")
end
end
function Base.setproperty!::UnitaryTangent, sym::Symbol, v)
error("type UnitaryTangent does not allow to change its fields")
end

# Basic vector space behaviour
Base.:+(Δ₁::UnitaryTangent, Δ₂::UnitaryTangent) =
UnitaryTangent(checkbase(Δ₁,Δ₂), Δ₁.A + Δ₂.A)
Expand Down Expand Up @@ -92,7 +81,7 @@ project(X, W; metric = :euclidean) = project!(copy(X), W; metric = :euclidean)
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg = nothing)
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
E = exp*Δ.A)
W′ = projectisometric!(W*E; alg = PolarNewton())
W′ = projectisometric!(W*E; alg = SDD())
A′ = Δ.A
return W′, UnitaryTangent(W′, A′)
end
Expand Down
8 changes: 6 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ const α = 0.75
Wend = TensorMap(randhaar, T, codomain(W), domain(W))
Δ3, V = Grassmann.invretract(W, Wend)
@test Wend retract(W, Δ3, 1)[1] * V
U = Grassmann.matchgauge(W, Wend)
U = Grassmann.relativegauge(W, Wend)
V2 = Grassmann.invretract(W, Wend * U)[2]
@test V2 one(V2)
end
end

@testset "Stiefel with space $V" for V in spaces
for T in (Float64, ComplexF64)
W, = leftorth(TensorMap(randn, T, V*V*V, V*V); alg = Polar())
W = TensorMap(randhaar, T, V*V*V, V*V)
X = TensorMap(randn, T, space(W))
Y = TensorMap(randn, T, space(W))
Δ = @inferred Stiefel.project_euclidean(X, W)
Expand Down Expand Up @@ -110,6 +110,10 @@ end
@test Stiefel.inner_euclidean(W2, Ξ2, Θ2) Stiefel.inner_euclidean(W, Ξ, Θ)
@test Stiefel.inner_canonical(W2, Δ2, Θ2) Stiefel.inner_canonical(W, Δ, Θ)
@test Stiefel.inner_canonical(W2, Ξ2, Θ2) Stiefel.inner_canonical(W, Ξ, Θ)

W3 = projectisometric!(W + 1e-1 * TensorMap(rand, T, codomain(W), domain(W)))
Δ3 = Stiefel.invretract(W, W3)
@test W3 retract(W, Δ3, 1)[1]
end
end

Expand Down

0 comments on commit 7f5d989

Please sign in to comment.