Skip to content

Commit

Permalink
Make chi functions not-in-place
Browse files Browse the repository at this point in the history
The `chi!` functions previously used by GRAPE and Krotov are now
simply `chi` do not act in-place. This is more general and easier to
implement for the user, as it allows to use immutable structs for states

Note that in extreme performance-critical situations, one could still
construct the χ-states in-place via a closure or functor.

Both chi and J_T can now have an optional keyword argument `tau`
(instead of the previous improperly implemented and unicode `τ`).
Whether or not `tau` should be passed to these functions is
automatically detected.
  • Loading branch information
goerz committed Jul 17, 2024
1 parent fefb799 commit a8ddddf
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 47 deletions.
89 changes: 50 additions & 39 deletions src/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using QuantumControlBase.QuantumPropagators.Generators: Operator
using QuantumControlBase.QuantumPropagators.Controls: discretize, evaluate
using QuantumControlBase.QuantumPropagators: prop_step!, reinit_prop!, propagate
using QuantumControlBase.QuantumPropagators.Storage: write_to_storage!, get_from_storage!
using QuantumControlBase.QuantumPropagators.Storage:
write_to_storage!, get_from_storage!, get_from_storage
using QuantumControlBase: make_chi, set_atexit_save_optimization
using QuantumControlBase: taus!
using QuantumControlBase: @threadsif, Trajectory
using LinearAlgebra
using Printf
Expand All @@ -24,9 +26,12 @@ with explicit keyword arguments to `optimize`.
# Required problem keyword arguments
* `J_T`: A function `J_T(ϕ, trajectories)` that evaluates the final time
functional from a list `ϕ` of forward-propagated states and
`problem.trajectories`.
* `J_T`: A function `J_T(Ψ, trajectories)` that evaluates the final time
functional from a list `Ψ` of forward-propagated states and
`problem.trajectories`. The function `J_T` may also take a keyword argument
`tau`. If it does, a vector containing the complex overlaps of the target
states (`target_state` property of each trajectory in `problem.trajectories`)
with the propagated states will be passed to `J_T`.
# Recommended problem keyword arguments
Expand All @@ -53,10 +58,12 @@ The following keyword arguments are supported (with default values):
This overrides the global `lambda_a` and `update_shape` arguments.
* `chi`: A function `chi!(χ, ϕ, trajectories)` that receives a list `ϕ`
of the forward propagated states and must set ``|χₖ⟩ = -∂J_T/∂⟨ϕₖ|``. If not
given, it will be automatically determined from `J_T` via [`make_chi`](@ref
QuantumControlBase.make_chi) with the default parameters.
* `chi`: A function `chi(Ψ, trajectories)` that receives a list `Ψ`
of the forward propagated states and returns a vector of states
``|χₖ⟩ = -∂J_T/∂⟨Ψₖ|``. If not given, it will be automatically determined
from `J_T` via [`make_chi`](@ref) with the default parameters. Similarly to
`J_T`, if `chi` accepts a keyword argument `tau`, it will be passed a vector
of complex overlaps.
* `sigma=nothing`: A function that calculates the second-order contribution. If
not given, the first-order Krotov method is used.
* `iter_start=0`: The initial iteration number.
Expand Down Expand Up @@ -132,14 +139,6 @@ function optimize_krotov(problem)
ϵ⁽ⁱ⁾ = wrk.pulses0
ϵ⁽ⁱ⁺¹⁾ = wrk.pulses1

if haskey(wrk.kwargs, :chi)
chi! = wrk.kwargs[:chi]
else
# we only want to evaluate `make_chi` if `chi` is not a kwarg
J_T_func = wrk.kwargs[:J_T]
chi! = make_chi(J_T_func, wrk.trajectories)
end

if skip_initial_forward_propagation
@info "Skipping initial forward propagation"
else
Expand Down Expand Up @@ -170,7 +169,7 @@ function optimize_krotov(problem)
try
while !wrk.result.converged
i = i + 1
krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾)
update_result!(wrk, i)
info_tuple = callback(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
if !(isnothing(info_tuple) || isempty(info_tuple))
Expand Down Expand Up @@ -208,18 +207,18 @@ function transform_control_ranges(c, ϵ_min, ϵ_max, check)
end


function krotov_initial_fw_prop!(ϵ⁽⁰⁾, ϕₖⁱⁿ, k, wrk)
function krotov_initial_fw_prop!(ϵ⁽⁰⁾, ϕₖ, k, wrk)
for propagator in wrk.fw_propagators
propagator.parameters = IdDict(zip(wrk.controls, ϵ⁽⁰⁾))
end
reinit_prop!(wrk.fw_propagators[k], ϕₖⁱⁿ; transform_control_ranges)
reinit_prop!(wrk.fw_propagators[k], ϕₖ; transform_control_ranges)

Φ₀ = wrk.fw_storage[k]
(Φ₀ !== nothing) && write_to_storage!(Φ₀, 1, ϕₖⁱⁿ)
(Φ₀ !== nothing) && write_to_storage!(Φ₀, 1, ϕₖ)
N_T = length(wrk.result.tlist) - 1
for n = 1:N_T
ϕₖ = prop_step!(wrk.fw_propagators[k])
(Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, ϕₖ)
Ψₖ = prop_step!(wrk.fw_propagators[k])
(Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, Ψₖ)
end
# TODO: allow a custom prop_step! routine
end
Expand All @@ -236,12 +235,8 @@ _eval_mu(μ::Operator, _...) = μ
_eval_mu::AbstractMatrix, _...) = μ


function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾)

χ = [
(ismutable(propagator.state) ? propagator.state : similar(propagator.state)) for
propagator in wrk.bw_propagators
]
tlist = wrk.result.tlist
N_T = length(tlist) - 1
N = length(wrk.trajectories)
Expand All @@ -250,14 +245,22 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
Φ = wrk.fw_storage # TODO: pass in Φ₁, Φ₀ as parameters
∫gₐdt = wrk.g_a_int
Im = imag
chi = wrk.kwargs[:chi] # guaranteed to exist in `KrotovWrk` constructor


guess_parameters = IdDict(zip(wrk.controls, ϵ⁽ⁱ⁾))
updated_parameters = IdDict(zip(wrk.controls, ϵ⁽ⁱ⁺¹⁾))

# backward propagation
ϕ = [propagator.state for propagator in wrk.fw_propagators]
chi!(χ, ϕ, wrk.trajectories)

Ψ = [propagator.state for propagator in wrk.fw_propagators]
if wrk.chi_takes_tau
χ = chi(Ψ, wrk.trajectories; tau=wrk.result.tau_vals)
else
χ = chi(Ψ, wrk.trajectories)

Check warning on line 260 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L260

Added line #L260 was not covered by tests
end
@threadsif wrk.use_threads for k = 1:N
# TODO: normalize χ; warn if norm is close to zero
wrk.bw_propagators[k].parameters = guess_parameters
reinit_prop!(wrk.bw_propagators[k], χ[k]; transform_control_ranges)
write_to_storage!(X[k], N_T + 1, χ[k])
Expand All @@ -271,26 +274,30 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)

@threadsif wrk.use_threads for k = 1:N
wrk.fw_propagators[k].parameters = updated_parameters
local ϕₖ = wrk.trajectories[k].initial_state
reinit_prop!(wrk.fw_propagators[k], ϕₖ; transform_control_ranges)
local Ψₖ = wrk.trajectories[k].initial_state
reinit_prop!(wrk.fw_propagators[k], Ψₖ; transform_control_ranges)
end

∫gₐdt .= 0.0
for n = 1:N_T # `n` is the index for the time interval
dt = tlist[n+1] - tlist[n]
for k = 1:N
get_from_storage!(χ[k], X[k], n)
if ismutable(χ[k])
get_from_storage!(χ[k], X[k], n)
else
χ[k] = get_from_storage(X[k], n)

Check warning on line 288 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L288

Added line #L288 was not covered by tests
end
end
ϵₙ⁽ⁱ⁺¹⁾ = [ϵ⁽ⁱ⁾[l][n] for l 1:L] # ϵₙ⁽ⁱ⁺¹⁾ ≈ ϵₙ⁽ⁱ⁾ for non-linear controls
# TODO: we could add a self-consistent loop here for ϵₙ⁽ⁱ⁺¹⁾
Δuₙ = zeros(L) # Δu is Δϵ without (Sₗₙ/λₐ) factor
for l = 1:L # `l` is the index for the different controls
for k = 1:N # k is the index over the trajectories
ϕₖ = wrk.fw_propagators[k].state
Ψₖ = wrk.fw_propagators[k].state
μₖₗ = wrk.control_derivs[k][l]
if !isnothing(μₖₗ)
μₗₖₙ = _eval_mu(μₖₗ, wrk, ϵₙ⁽ⁱ⁺¹⁾, tlist, n)
Δuₙ[l] += Im(dot(χ[k], μₗₖₙ, ϕₖ))
Δuₙ[l] += Im(dot(χ[k], μₗₖₙ, Ψₖ))
end
end
end
Expand All @@ -305,8 +312,8 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
end
# TODO: end of self-consistent loop
@threadsif wrk.use_threads for k = 1:N
local ϕₖ = prop_step!(wrk.fw_propagators[k])
write_to_storage!(Φ[k], n, ϕₖ)
local Ψₖ = prop_step!(wrk.fw_propagators[k])
write_to_storage!(Φ[k], n, Ψₖ)
end
# TODO: update sigma
end # time loop
Expand All @@ -315,12 +322,17 @@ end

function update_result!(wrk::KrotovWrk, i::Int64)
res = wrk.result
J_T_func = wrk.kwargs[:J_T]
J_T = wrk.kwargs[:J_T]
res.J_T_prev = res.J_T
for k in eachindex(wrk.fw_propagators)
res.states[k] = wrk.fw_propagators[k].state
end
res.J_T = J_T_func(res.states, wrk.trajectories)
taus!(res.tau_vals, res.states, wrk.trajectories; ignore_missing_target_state=true)
if wrk.J_T_takes_tau
res.J_T = J_T(res.states, wrk.trajectories; tau=res.tau_vals)
else
res.J_T = J_T(res.states, wrk.trajectories)

Check warning on line 334 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L334

Added line #L334 was not covered by tests
end
(i > 0) && (res.iter = i)
if i >= res.iter_stop
res.converged = true
Expand All @@ -331,7 +343,6 @@ function update_result!(wrk::KrotovWrk, i::Int64)
prev_time = res.end_local_time
res.end_local_time = now()
res.secs = Dates.toms(res.end_local_time - prev_time) / 1000.0
# TODO: calculate τ values
end


Expand Down
17 changes: 10 additions & 7 deletions src/result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ The attributes of a `KrotovResult` object include
* `iter`: The number of the current iteration
* `J_T`: The value of the final-time functional in the current iteration
* `J_T_prev`: The value of the final-time functional in the previous iteration
* `tlist`: The time grid on which the control are discetized.
* `tlist`: The time grid on which the control are discretized.
* `guess_controls`: A vector of the original control fields (each field
discretized to the points of `tlist`)
* optimized_controls: A vector of the optimized control fileds in the current
iterations
* records: A vector of tuples with values returned by a `callback` routine
* `optimized_controls`: A vector of the optimized control fields. Calculated only
at the end of the optimization, not after each iteration.
* `tau_vals`: For any trajectory that defines a `target_state`, the complex
overlap of that target state with the propagated state. For any trajectory
for which the `target_state` is `nothing`, the value is zero.
* `records`: A vector of tuples with values returned by a `callback` routine
passed to [`optimize`](@ref)
* converged: A boolean flag on whether the optimization is converged. This
* `converged`: A boolean flag on whether the optimization is converged. This
may be set to `true` by a `check_convergence` function.
* message: A message string to explain the reason for convergence. This may be
* `message`: A message string to explain the reason for convergence. This may be
set by a `check_convergence` function.
All of the above attributes may be referenced in a `check_convergence` function
Expand Down Expand Up @@ -53,7 +56,7 @@ function KrotovResult(problem)
iter_stop = get(problem.kwargs, :iter_stop, 5000)
iter = iter_start
secs = 0
tau_vals = Vector{ComplexF64}()
tau_vals = zeros(ComplexF64, length(problem.trajectories))
guess_controls = [discretize(control, tlist) for control in controls]
J_T = 0.0
J_T_prev = 0.0
Expand Down
20 changes: 19 additions & 1 deletion src/workspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ mutable struct KrotovWrk
g_a_int::Vector{Float64}
update_shapes::Vector{Vector{Float64}}
lambda_vals::Vector{Float64}
J_T_takes_tau::Bool # Does J_T have a tau keyword arg?
chi_takes_tau::Bool # Does chi have a tau keyword arg?
# map of controls to options
result

Expand Down Expand Up @@ -150,6 +152,21 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false)
kwargs...
) for (k, traj) in enumerate(adjoint_trajectories)
]
J_T_takes_tau = false
if haskey(kwargs, :J_T)
J_T = kwargs[:J_T]
else
msg = "`optimize` for `method=Krotov` must be passed the functional `J_T`."
throw(ArgumentError(msg))

Check warning on line 160 in src/workspace.jl

View check run for this annotation

Codecov / codecov/patch

src/workspace.jl#L159-L160

Added lines #L159 - L160 were not covered by tests
end
J_T_takes_tau =
hasmethod(J_T, Tuple{typeof(result.states),typeof(trajectories)}, (:tau,))
if !haskey(kwargs, :chi)
kwargs[:chi] = make_chi(J_T, trajectories)
end
chi = kwargs[:chi]
chi_takes_tau =
hasmethod(chi, Tuple{typeof(result.states),typeof(trajectories)}, (:tau,))
return KrotovWrk(
trajectories,
adjoint_trajectories,
Expand All @@ -160,7 +177,8 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false)
g_a_int,
update_shapes,
lambda_vals,
# pulse_options, # XXX
J_T_takes_tau,
chi_takes_tau,
result,
control_derivs,
fw_storage,
Expand Down

0 comments on commit a8ddddf

Please sign in to comment.