Skip to content

Commit

Permalink
Combine info-hook/update-hook into callback
Browse files Browse the repository at this point in the history
See JuliaQuantumControl/QuantumControlBase.jl#84

This also enables the new `store_iter_info` argument.
  • Loading branch information
goerz committed Jul 9, 2024
1 parent 2f7c298 commit 54fa928
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 107 deletions.
198 changes: 126 additions & 72 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using QuantumControlBase: @threadsif, Trajectory
using LinearAlgebra
using Printf

import QuantumControlBase: optimize
import QuantumControlBase: optimize, make_print_iters

@doc raw"""
```julia
Expand All @@ -19,7 +19,7 @@ optimizes the given control [`problem`](@ref QuantumControlBase.ControlProblem)
using Krotov's method, returning a [`KrotovResult`](@ref).
Keyword arguments that control the optimization are taken from the keyword
arguments used in the instantiation of `problem`; any of these can be overriden
arguments used in the instantiation of `problem`; any of these can be overridden
with explicit keyword arguments to `optimize`.
# Required problem keyword arguments
Expand All @@ -30,13 +30,13 @@ with explicit keyword arguments to `optimize`.
# Recommended problem keyword arguments
* `lambda_a=1.0`: The inverse Krotov step width λ_a for every pulse.
* `lambda_a=1.0`: The inverse Krotov step width λₐ for every pulse.
* `update_shape=(t->1.0)`: A function `S(t)` for the "update shape" that scales
the update for every pulse
the update for every pulse.
If different controls require different `lambda_a` or `update_shape`, a dict
`pulse_options` must be given instead of a global `lambda_a` and
`update_shape`, see below.
`update_shape`; see below.
# Optional problem keyword arguments
Expand All @@ -47,40 +47,43 @@ The following keyword arguments are supported (with default values):
QuantumControlBase.QuantumPropagators.Controls.get_controls) from the
`problem.trajectories`) to the following dict:
- `:lambda_a`: The value for inverse Krotov step width λₐ
- `:lambda_a`: The value for inverse Krotov step width λₐ.
- `:update_shape`: A function `S(t)` for the "update shape" that scales
the Krotov pulse update.
This overrides the global `lambda_a` and `update_shape` arguments.
* `chi`: A function `chi!(χ, ϕ, trajectories)` what receives a list `ϕ`
* `chi`: A function `chi!(χ, ϕ, trajectories)` that receives a list `ϕ`
of the forward propagated states and must set ``|χₖ⟩ = -∂J_T/∂⟨ϕₖ|``. If not
given, it will be automatically determined from `J_T` via [`make_chi`](@ref
QuantumControlBase.make_chi) with the default parameters.
* `sigma=nothing`: Function that calculate the second-order contribution. If
* `sigma=nothing`: A function that calculates the second-order contribution. If
not given, the first-order Krotov method is used.
* `iter_start=0`: the initial iteration number
* `iter_stop=5000`: the maximum iteration number
* `prop_method`: The propagation method to use for each trajectory, see below.
* `update_hook`: A function that receives the Krotov workspace, the iteration
number, the list of updated pulses and the list of guess pulses as
positional arguments. The function may mutate any of its arguments. This may
be used e.g. to apply a spectral filter to the updated pulses or to perform
similar manipulations.
* `info_hook`: A function (or tuple of functions) that receives the same
arguments as `update_hook`, in order to write information about the current
iteration to the screen or to a file. The default `info_hook` prints a table
with convergence information to the screen. Runs after `update_hook`. The
`info_hook` function may return a tuple, which is stored in the list of
`records` inside the [`KrotovResult`](@ref) object.
* `check_convergence`: a function to check whether convergence has been
* `iter_start=0`: The initial iteration number.
* `iter_stop=5000`: The maximum iteration number.
* `prop_method`: The propagation method to use for each trajectory; see below.
* `print_iters=true`: Whether to print information after each iteration.
* `store_iter_info=Set()`: Which fields from `print_iters` to store in
`result.records`. A subset of
`Set(["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"])`.
* `callback`: A function (or tuple of functions) that receives the
[Krotov workspace](@ref KrotovWrk), the iteration number, the list of updated
pulses, and the list of guess pulses as positional arguments. The function
may return a tuple of values which are stored in the
[`KrotovResult`](@ref) object `result.records`. The function can also mutate
any of its arguments, in particular the updated pulses. This may be used,
e.g., to apply a spectral filter to the updated pulses or to perform
similar manipulations. Note that `print_iters=true` (default) adds an
automatic callback to print information after each iteration. With
`store_iter_info`, that callback automatically stores a subset of the
printed information.
* `check_convergence`: A function to check whether convergence has been
reached. Receives a [`KrotovResult`](@ref) object `result`, and should set
`result.converged` to `true` and `result.message` to an appropriate string in
case of convergence. Multiple convergence checks can be performed by chaining
functions with `∘`. The convergence check is performed after any calls to
`update_hook` and `info_hook`.
* `verbose=false`: If `true`, print information during initialization
* `rethrow_exceptions`: By default, any exception ends the optimization, but
functions with `∘`. The convergence check is performed after any `callback`.
* `verbose=false`: If `true`, print information during initialization.
* `rethrow_exceptions`: By default, any exception ends the optimization but
still returns a [`KrotovResult`](@ref) that captures the message associated
with the exception. This is to avoid losing results from a long-running
optimization when an exception occurs in a later iteration. If
Expand All @@ -97,7 +100,7 @@ each [`Trajectory`](@ref) that have a `prop_` prefix, cf.
In situations where different parameters are required for the forward and
backward propagation, instead of the `prop_` prefix, the `fw_prop_` and
`bw_prop_` prefix can be used, respectively. These override any setting with
`bw_prop_` prefixes can be used, respectively. These override any setting with
the `prop_` prefix. This applies both to the properties of each
[`Trajectory`](@ref) and the problem keyword arguments.
Expand All @@ -112,10 +115,11 @@ optimize(problem, method::Val{:krotov}) = optimize_krotov(problem)
See [`optimize(problem; method=Krotov, kwargs...)`](@ref optimize(::Any, ::Val{:krotov})).
"""
function optimize_krotov(problem)
sigma = get(problem.kwargs, :sigma, nothing)
iter_start = get(problem.kwargs, :iter_start, 0)
update_hook! = get(problem.kwargs, :update_hook, (args...) -> nothing)
info_hook = get(problem.kwargs, :info_hook, print_table)
callback = get(problem.kwargs, :callback, (args...) -> nothing)
if haskey(problem.kwargs, :update_hook) || haskey(problem.kwargs, :info_hook)
msg = "The `update_hook` and `info_hook` arguments have been superseded by the `callback` argument"
throw(ArgumentError(msg))
end
check_convergence! = get(problem.kwargs, :check_convergence, res -> res)
# note: the default `check_convergence!` is a no-op. We still always check
# for "Reached maximum number of iterations" in `update_result!`
Expand Down Expand Up @@ -146,9 +150,10 @@ function optimize_krotov(problem)

# TODO: if sigma, fw_storage0 = fw_storage
update_result!(wrk, 0)
update_hook!(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
info_tuple = info_hook(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
info_tuple = callback(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
if !(isnothing(info_tuple) || isempty(info_tuple))
push!(wrk.result.records, info_tuple)
end

i = wrk.result.iter # = 0, unless continuing from previous optimization
atexit_filename = get(problem.kwargs, :atexit_filename, nothing)
Expand All @@ -167,9 +172,10 @@ function optimize_krotov(problem)
i = i + 1
krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
update_result!(wrk, i)
update_hook!(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
info_tuple = info_hook(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
info_tuple = callback(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾)
if !(isnothing(info_tuple) || isempty(info_tuple))
push!(wrk.result.records, info_tuple)
end
check_convergence!(wrk.result)
ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾ = ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾
end
Expand Down Expand Up @@ -236,7 +242,6 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!)
(ismutable(propagator.state) ? propagator.state : similar(propagator.state)) for
propagator in wrk.bw_propagators
]
J_T_func = wrk.kwargs[:J_T]
tlist = wrk.result.tlist
N_T = length(tlist) - 1
N = length(wrk.trajectories)
Expand Down Expand Up @@ -339,42 +344,91 @@ function finalize_result!(ϵ_opt, wrk::KrotovWrk)
end


"""Print optimization progress as a table.
make_print_iters(::Val{:Krotov}; kwargs...) = make_krotov_print_iters(; kwargs...)
make_print_iters(::Val{:krotov}; kwargs...) = make_krotov_print_iters(; kwargs...)


function make_krotov_print_iters(; kwargs...)

header = ["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"]
store_iter_info = Set(get(kwargs, :store_iter_info, Set()))
info_vals = Vector{Any}(undef, length(header))
fill!(info_vals, nothing)
store_iter = false
store_J_T = false
store_g_a_int = false
store_J = false
store_ΔJ_T = false
store_ΔJ = false
store_secs = false
for item in store_iter_info
if item == "iter."
store_iter = true
elseif item == "J_T"
store_J_T = true
elseif item == "∫gₐ(t)dt"
store_g_a_int = true
elseif item == "J"
store_J = true
elseif item == "ΔJ_T"
store_ΔJ_T = true
elseif item == "ΔJ"
store_ΔJ = true
elseif item == "secs"
store_secs = true
else
msg = "Item $(repr(item)) in `store_iter_info` is not one of $(repr(header)))"
throw(ArgumentError(msg))
end
end


This functions serves as the default `info_hook` for an optimization with
Krotov's method.
"""
function print_table(wrk, iteration, args...)
J_T = wrk.result.J_T
g_a_int = sum(wrk.g_a_int)
J = J_T + g_a_int
ΔJ_T = J_T - wrk.result.J_T_prev
ΔJ = ΔJ_T + g_a_int
secs = wrk.result.secs

iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
widths = [max(length("$iter_stop"), 6), 11, 11, 11, 11, 11, 8]

if iteration == 0
header = ["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"]
for (header, w) in zip(header, widths)
print(lpad(header, w))
function print_table(wrk, iteration, args...)

J_T = wrk.result.J_T
g_a_int = sum(wrk.g_a_int)
J = J_T + g_a_int
ΔJ_T = J_T - wrk.result.J_T_prev
ΔJ = ΔJ_T + g_a_int
secs = wrk.result.secs

store_iter && (info_vals[1] = iteration)
store_J_T && (info_vals[2] = J_T)
store_g_a_int && (info_vals[3] = g_a_int)
store_J && (info_vals[4] = J)
store_ΔJ_T && (info_vals[5] = ΔJ_T)
store_ΔJ && (info_vals[6] = ΔJ)
store_secs && (info_vals[7] = secs)

iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
widths = [max(length("$iter_stop"), 6), 11, 11, 11, 11, 11, 8]

if iteration == 0
for (header, w) in zip(header, widths)
print(lpad(header, w))
end
print("\n")
end

strs = (
"$iteration",
@sprintf("%.2e", J_T),
@sprintf("%.2e", g_a_int),
@sprintf("%.2e", J),
(iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
(iteration > 0) ? @sprintf("%.2e", ΔJ) : "n/a",
@sprintf("%.1f", secs),
)
for (str, w) in zip(strs, widths)
print(lpad(str, w))
end
print("\n")
end
flush(stdout)

return Tuple((value for value in info_vals if (value !== nothing)))

strs = (
"$iteration",
@sprintf("%.2e", J_T),
@sprintf("%.2e", g_a_int),
@sprintf("%.2e", J),
(iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
(iteration > 0) ? @sprintf("%.2e", ΔJ) : "n/a",
@sprintf("%.1f", secs),
)
for (str, w) in zip(strs, widths)
print(lpad(str, w))
end
print("\n")
flush(stdout)

return print_table

end
4 changes: 2 additions & 2 deletions src/result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The attributes of a `KrotovResult` object include
discretized to the points of `tlist`)
* optimized_controls: A vector of the optimized control fileds in the current
iterations
* records: A vector of tuples with values returned by an `info_hook` routine
* records: A vector of tuples with values returned by a `callback` routine
passed to [`optimize`](@ref)
* converged: A boolean flag on whether the optimization is converged. This
may be set to `true` by a `check_convergence` function.
Expand All @@ -41,7 +41,7 @@ mutable struct KrotovResult{STST}
states::Vector{STST} # the forward-propagated states after each iteration
start_local_time::DateTime
end_local_time::DateTime
records::Vector{Tuple} # storage for info_hook to write data into at each iteration
records::Vector{Tuple} # storage for callback to write data into at each iteration
converged::Bool
message::String
end
Expand Down
Loading

0 comments on commit 54fa928

Please sign in to comment.