Skip to content

Commit

Permalink
Add support for in-place interface
Browse files Browse the repository at this point in the history
The wrapped vectors or operators can now be StaticArrays
  • Loading branch information
goerz committed Jul 26, 2024
1 parent ad2e8ed commit d02023b
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 46 deletions.
26 changes: 14 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,31 @@ using Documenter
using DocumenterInterLinks
using Pkg

DocMeta.setdocmeta!(
QuantumGradientGenerators,
:DocTestSetup,
:(using QuantumGradientGenerators);
recursive=true
)

PROJECT_TOML = Pkg.TOML.parsefile(joinpath(@__DIR__, "..", "Project.toml"))
VERSION = PROJECT_TOML["version"]
NAME = PROJECT_TOML["name"]
AUTHORS = join(PROJECT_TOML["authors"], ", ") * " and contributors"
GITHUB = "https://github.com/JuliaQuantumControl/QuantumGradientGenerators.jl"

DEV_OR_STABLE = "stable/"
if endswith(VERSION, "dev")
DEV_OR_STABLE = "dev/"
end

links = InterLinks(
"Julia" => "https://docs.julialang.org/en/v1/",
"QuantumPropagators" => "https://juliaquantumcontrol.github.io/QuantumPropagators.jl/$DEV_OR_STABLE",
"QuantumControl" => "https://juliaquantumcontrol.github.io/QuantumControl.jl/$DEV_OR_STABLE",
)

println("Starting makedocs")

makedocs(;
plugins=[links],
authors=AUTHORS,
sitename="QuantumGradientGenerators.jl",
modules=[QuantumGradientGenerators],
repo="https://github.com/JuliaQuantumControl/QuantumGradientGenerators.jl/blob/{commit}{path}#{line}",
doctest=false,
format=Documenter.HTML(;
prettyurls=true,
canonical="https://juliaquantumcontrol.github.io/QuantumGradientGenerators.jl",
Expand All @@ -41,7 +46,4 @@ makedocs(;

println("Finished makedocs")

deploydocs(;
repo="github.com/JuliaQuantumControl/QuantumGradientGenerators.jl",
devbranch="master"
)
deploydocs(; repo="github.com/JuliaQuantumControl/QuantumGradientGenerators.jl",)
30 changes: 22 additions & 8 deletions src/grad_vector.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import QuantumControlBase.QuantumPropagators: _exp_prop_convert_state
import QuantumControlBase.QuantumPropagators.Interfaces: supports_inplace


@doc raw"""Extended state-vector for the dynamic gradient.
Expand Down Expand Up @@ -34,17 +35,16 @@ e^{-i G̃ dt} \begin{pmatrix} 0 \\ \vdots \\ 0 \\ |Ψ⟩ \end{pmatrix}
e^{-i Ĥ dt} |Ψ⟩
\end{pmatrix}.
```
Upon initialization, ``|Ψ̃₁⟩…|Ψ̃ₙ⟩`` are zero.
"""
struct GradVector{num_controls,T}
state::T
grad_states::Vector{T}
end

function GradVector::T, num_controls::Int64) where {T}
grad_states = [similar(Ψ) for _ = 1:num_controls]
for i = 1:num_controls
fill!(grad_states[i], 0.0)
end
grad_states = [zero(Ψ) for _ = 1:num_controls]
GradVector{num_controls,T}(copy(Ψ), grad_states)
end

Expand All @@ -55,18 +55,30 @@ end
resetgradvec!(Ψ̃::GradVector)
```
zeroes out `Ψ̃.grad_states` but leaves `Ψ̃.state` unaffected.
zeroes out `Ψ̃.grad_states` but leaves `Ψ̃.state` unaffected. This is possible
whether or not Ψ̃ supports in-place operations
([`QuantumPropagators.Interfaces.supports_inplace`](@ref))
```julia
resetgradvec!(Ψ̃::GradVector, Ψ)
```
additionally sets `Ψ̃.state` to `Ψ`.
additionally sets `Ψ̃.state` to `Ψ`, which requires that `Ψ̃.state` supports
in-place operations.
Returns `Ψ̃`.
"""
function resetgradvec!(Ψ̃::GradVector)
for i = 1:length(Ψ̃.grad_states)
fill!(Ψ̃.grad_states[i], 0.0)
if supports_inplace(Ψ̃)
for i in eachindex(Ψ̃.grad_states)
fill!(Ψ̃.grad_states[i], 0.0)
end
else
for i in eachindex(Ψ̃.grad_states)
Ψ̃.grad_states[i] = zero(Ψ̃.state)
end
end
return Ψ̃
end

function resetgradvec!(Ψ̃::GradVector{num_controls,T}, Ψ::T) where {num_controls,T}
Expand All @@ -76,3 +88,5 @@ end


_exp_prop_convert_state(::GradVector) = Vector{ComplexF64}

supports_inplace(Ψ̃::GradVector) = supports_inplace(Ψ̃.state)
3 changes: 3 additions & 0 deletions src/gradgen_operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Random: GLOBAL_RNG
import QuantumControlBase.QuantumPropagators: _exp_prop_convert_operator
import QuantumControlBase.QuantumPropagators.Controls: get_controls
import QuantumControlBase.QuantumPropagators.SpectralRange: random_state
import QuantumControlBase.QuantumPropagators.Interfaces: supports_inplace


"""Static generator for the dynamic gradient.
Expand Down Expand Up @@ -38,3 +39,5 @@ end


_exp_prop_convert_operator(::GradgenOperator) = Matrix{ComplexF64}

supports_inplace(::GradgenOperator) = true
61 changes: 35 additions & 26 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,8 @@ function Base.copyto!(dest::GradVector, src::GradVector)
end


function Base.copy::GradVector)
Φ = GradVector.state, length.grad_states))
for i = 1:length.grad_states)
copyto!.grad_states[i], Ψ.grad_states[i])
end
return Φ
function Base.copy::GradVector{num_controls,T}) where {num_controls,T}
return GradVector{num_controls,T}(copy.state), [copy(ϕ) for ϕ in Ψ.grad_states])
end


Expand Down Expand Up @@ -108,23 +104,30 @@ function Base.fill!(Ψ::GradVector, v)
end


function -::GradVector, Φ::GradVector)
res = copy(Ψ)
LinearAlgebra.axpy!(-1, Φ.state, res.state)
for i = 1:length.grad_states)
LinearAlgebra.axpy!(-1, Φ.grad_states[i], res.grad_states[i])
end
return res
function Base.zero::GradVector{num_controls,T}) where {num_controls,T}
return GradVector{num_controls,T}(zero.state), [zero(ϕ) for ϕ Ψ.grad_states])
end


function +::GradVector, Φ::GradVector)
res = copy(Ψ)
LinearAlgebra.axpy!(1, Φ.state, res.state)
for i = 1:length.grad_states)
LinearAlgebra.axpy!(1, Φ.grad_states[i], res.grad_states[i])
end
return res
function -(
Ψ::GradVector{num_controls,T},
Φ::GradVector{num_controls,T}
) where {num_controls,T}
return GradVector{num_controls,T}(
Ψ.state - Φ.state,
[a - b for (a, b) in zip.grad_states, Φ.grad_states)]
)
end


function +(
Ψ::GradVector{num_controls,T},
Φ::GradVector{num_controls,T}
) where {num_controls,T}
return GradVector{num_controls,T}(
Ψ.state + Φ.state,
[a + b for (a, b) in zip.grad_states, Φ.grad_states)]
)
end


Expand All @@ -149,8 +152,12 @@ function *(
G::GradgenOperator{num_controls,GT,CGT},
Ψ::GradVector{num_controls,ST}
) where {num_controls,GT,CGT,ST}
Φ = similar(Ψ)
return LinearAlgebra.mul!(Φ, G, Ψ)
state = G.G * Ψ.state
grad_states = [G.G * ϕ for ϕ in Ψ.grad_states]
for (i, Hₙ) in enumerate(G.control_deriv_ops)
grad_states[i] += Hₙ * Ψ.state
end
return GradVector{num_controls,ST}(state, grad_states)
end


Expand Down Expand Up @@ -211,16 +218,18 @@ function Base.convert(::Type{Vector{ComplexF64}}, gradvec::GradVector)
end


function Base.convert(::Type{GradVector{num_controls,T}}, vec::T) where {num_controls,T}
function Base.convert(
::Type{GradVector{num_controls,T}},
vec::AbstractVector
) where {num_controls,T}
L = num_controls
N = length(vec) ÷ (L + 1) # dimension of state
@assert length(vec) == (L + 1) * N
grad_states = [vec[(i-1)*N+1:i*N] for i = 1:L]
state = vec[L*N+1:(L+1)*N]
grad_states = [convert(T, vec[(i-1)*N+1:i*N]) for i = 1:L]
state = convert(T, vec[L*N+1:(L+1)*N])
return GradVector{num_controls,T}(state, grad_states)
end


function Base.Array{T}(G::GradgenOperator) where {T}
N, M = size(G.G)
L = length(G.control_deriv_ops)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
GRAPE = "6b52fcaf-80fe-489a-93e9-9f92080510be"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand All @@ -12,6 +13,7 @@ QuantumControlTestUtils = "d3fd27c9-1dfb-4e67-b0c0-90d0d87a1e48"
QuantumGradientGenerators = "a563f35e-61db-434d-8c01-8b9e3ccdfd85"
QuantumPropagators = "7bf12567-5742-4b91-a078-644e72a65fc1"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
35 changes: 35 additions & 0 deletions test/test_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using QuantumControlTestUtils.RandomObjects: random_matrix, random_state_vector
using QuantumControlBase: check_generator
using QuantumPropagators.Interfaces: check_state
using QuantumGradientGenerators: GradGenerator, GradVector
using StaticArrays: SVector, SMatrix
using LinearAlgebra: norm


Expand All @@ -20,6 +21,18 @@ using LinearAlgebra: norm
end


@testset "GradVector Interface (Static)" begin

N = 10
Ψ = SVector{N,ComplexF64}(random_state_vector(N))
Ψ̃ = GradVector(Ψ, 2)
@test check_state(Ψ̃)

@test norm(2.2 * Ψ̃ - Ψ̃ * 2.2) < 1e-14

end


@testset "GradGenerator Interface" begin

N = 10
Expand All @@ -40,3 +53,25 @@ end
@test check_generator(G̃_of_t; state=Ψ̃, tlist, for_gradient_optimization=false)

end


@testset "GradGenerator Interface (Static)" begin

N = 10
Ĥ₀ = SMatrix{N,N,ComplexF64}(random_matrix(N, hermitian=true))
Ĥ₁ = SMatrix{N,N,ComplexF64}(random_matrix(N, hermitian=true))
Ĥ₂ = SMatrix{N,N,ComplexF64}(random_matrix(N, hermitian=true))
ϵ₁(t) = 1.0
ϵ₂(t) = 1.0
Ĥ_of_t = hamiltonian(Ĥ₀, (Ĥ₁, ϵ₁), (Ĥ₂, ϵ₂))

tlist = collect(range(0, 10; length=101))

G̃_of_t = GradGenerator(Ĥ_of_t)

Ψ = SVector{N,ComplexF64}(random_state_vector(N))
Ψ̃ = GradVector(Ψ, length(get_controls(G̃_of_t)))

@test check_generator(G̃_of_t; state=Ψ̃, tlist, for_gradient_optimization=false)

end

0 comments on commit d02023b

Please sign in to comment.