diff --git a/src/optimize.jl b/src/optimize.jl index 5c88cd8..a6eb3c3 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -7,7 +7,7 @@ using QuantumControlBase: @threadsif, Trajectory using LinearAlgebra using Printf -import QuantumControlBase: optimize +import QuantumControlBase: optimize, make_print_iters @doc raw""" ```julia @@ -19,7 +19,7 @@ optimizes the given control [`problem`](@ref QuantumControlBase.ControlProblem) using Krotov's method, returning a [`KrotovResult`](@ref). Keyword arguments that control the optimization are taken from the keyword -arguments used in the instantiation of `problem`; any of these can be overriden +arguments used in the instantiation of `problem`; any of these can be overridden with explicit keyword arguments to `optimize`. # Required problem keyword arguments @@ -30,13 +30,13 @@ with explicit keyword arguments to `optimize`. # Recommended problem keyword arguments -* `lambda_a=1.0`: The inverse Krotov step width λ_a for every pulse. +* `lambda_a=1.0`: The inverse Krotov step width λₐ for every pulse. * `update_shape=(t->1.0)`: A function `S(t)` for the "update shape" that scales - the update for every pulse + the update for every pulse. If different controls require different `lambda_a` or `update_shape`, a dict `pulse_options` must be given instead of a global `lambda_a` and -`update_shape`, see below. +`update_shape`; see below. # Optional problem keyword arguments @@ -47,40 +47,43 @@ The following keyword arguments are supported (with default values): QuantumControlBase.QuantumPropagators.Controls.get_controls) from the `problem.trajectories`) to the following dict: - - `:lambda_a`: The value for inverse Krotov step width λₐ + - `:lambda_a`: The value for inverse Krotov step width λₐ. - `:update_shape`: A function `S(t)` for the "update shape" that scales the Krotov pulse update. This overrides the global `lambda_a` and `update_shape` arguments. 
-* `chi`: A function `chi!(χ, ϕ, trajectories)` what receives a list `ϕ` +* `chi`: A function `chi!(χ, ϕ, trajectories)` that receives a list `ϕ` of the forward propagated states and must set ``|χₖ⟩ = -∂J_T/∂⟨ϕₖ|``. If not given, it will be automatically determined from `J_T` via [`make_chi`](@ref QuantumControlBase.make_chi) with the default parameters. -* `sigma=nothing`: Function that calculate the second-order contribution. If +* `sigma=nothing`: A function that calculates the second-order contribution. If not given, the first-order Krotov method is used. -* `iter_start=0`: the initial iteration number -* `iter_stop=5000`: the maximum iteration number -* `prop_method`: The propagation method to use for each trajectory, see below. -* `update_hook`: A function that receives the Krotov workspace, the iteration - number, the list of updated pulses and the list of guess pulses as - positional arguments. The function may mutate any of its arguments. This may - be used e.g. to apply a spectral filter to the updated pulses or to perform - similar manipulations. -* `info_hook`: A function (or tuple of functions) that receives the same - arguments as `update_hook`, in order to write information about the current - iteration to the screen or to a file. The default `info_hook` prints a table - with convergence information to the screen. Runs after `update_hook`. The - `info_hook` function may return a tuple, which is stored in the list of - `records` inside the [`KrotovResult`](@ref) object. -* `check_convergence`: a function to check whether convergence has been +* `iter_start=0`: The initial iteration number. +* `iter_stop=5000`: The maximum iteration number. +* `prop_method`: The propagation method to use for each trajectory; see below. +* `print_iters=true`: Whether to print information after each iteration. +* `store_iter_info=Set()`: Which fields from `print_iters` to store in + `result.records`. 
A subset of + `Set(["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"])`. +* `callback`: A function (or tuple of functions) that receives the + [Krotov workspace](@ref KrotovWrk), the iteration number, the list of updated + pulses, and the list of guess pulses as positional arguments. The function + may return a tuple of values which are stored in the + [`KrotovResult`](@ref) object `result.records`. The function can also mutate + any of its arguments, in particular the updated pulses. This may be used, + e.g., to apply a spectral filter to the updated pulses or to perform + similar manipulations. Note that `print_iters=true` (default) adds an + automatic callback to print information after each iteration. With + `store_iter_info`, that callback automatically stores a subset of the + printed information. +* `check_convergence`: A function to check whether convergence has been reached. Receives a [`KrotovResult`](@ref) object `result`, and should set `result.converged` to `true` and `result.message` to an appropriate string in case of convergence. Multiple convergence checks can be performed by chaining - functions with `∘`. The convergence check is performed after any calls to - `update_hook` and `info_hook`. -* `verbose=false`: If `true`, print information during initialization -* `rethrow_exceptions`: By default, any exception ends the optimization, but + functions with `∘`. The convergence check is performed after any `callback`. +* `verbose=false`: If `true`, print information during initialization. +* `rethrow_exceptions`: By default, any exception ends the optimization but still returns a [`KrotovResult`](@ref) that captures the message associated with the exception. This is to avoid losing results from a long-running optimization when an exception occurs in a later iteration. If @@ -97,7 +100,7 @@ each [`Trajectory`](@ref) that have a `prop_` prefix, cf. 
In situations where different parameters are required for the forward and backward propagation, instead of the `prop_` prefix, the `fw_prop_` and -`bw_prop_` prefix can be used, respectively. These override any setting with +`bw_prop_` prefixes can be used, respectively. These override any setting with the `prop_` prefix. This applies both to the properties of each [`Trajectory`](@ref) and the problem keyword arguments. @@ -112,10 +115,11 @@ optimize(problem, method::Val{:krotov}) = optimize_krotov(problem) See [`optimize(problem; method=Krotov, kwargs...)`](@ref optimize(::Any, ::Val{:krotov})). """ function optimize_krotov(problem) - sigma = get(problem.kwargs, :sigma, nothing) - iter_start = get(problem.kwargs, :iter_start, 0) - update_hook! = get(problem.kwargs, :update_hook, (args...) -> nothing) - info_hook = get(problem.kwargs, :info_hook, print_table) + callback = get(problem.kwargs, :callback, (args...) -> nothing) + if haskey(problem.kwargs, :update_hook) || haskey(problem.kwargs, :info_hook) + msg = "The `update_hook` and `info_hook` arguments have been superseded by the `callback` argument" + throw(ArgumentError(msg)) + end check_convergence! = get(problem.kwargs, :check_convergence, res -> res) # note: the default `check_convergence!` is a no-op. 
We still always check # for "Reached maximum number of iterations" in `update_result!` @@ -146,9 +150,10 @@ function optimize_krotov(problem) # TODO: if sigma, fw_storage0 = fw_storage update_result!(wrk, 0) - update_hook!(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) - info_tuple = info_hook(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) - (info_tuple !== nothing) && push!(wrk.result.records, info_tuple) + info_tuple = callback(wrk, 0, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) + if !(isnothing(info_tuple) || isempty(info_tuple)) + push!(wrk.result.records, info_tuple) + end i = wrk.result.iter # = 0, unless continuing from previous optimization atexit_filename = get(problem.kwargs, :atexit_filename, nothing) @@ -167,9 +172,10 @@ function optimize_krotov(problem) i = i + 1 krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) update_result!(wrk, i) - update_hook!(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) - info_tuple = info_hook(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) - (info_tuple !== nothing) && push!(wrk.result.records, info_tuple) + info_tuple = callback(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) + if !(isnothing(info_tuple) || isempty(info_tuple)) + push!(wrk.result.records, info_tuple) + end check_convergence!(wrk.result) ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾ = ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾ end @@ -236,7 +242,6 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) (ismutable(propagator.state) ? propagator.state : similar(propagator.state)) for propagator in wrk.bw_propagators ] - J_T_func = wrk.kwargs[:J_T] tlist = wrk.result.tlist N_T = length(tlist) - 1 N = length(wrk.trajectories) @@ -339,42 +344,91 @@ function finalize_result!(ϵ_opt, wrk::KrotovWrk) end -"""Print optimization progress as a table. +make_print_iters(::Val{:Krotov}; kwargs...) = make_krotov_print_iters(; kwargs...) +make_print_iters(::Val{:krotov}; kwargs...) = make_krotov_print_iters(; kwargs...) + + +function make_krotov_print_iters(; kwargs...) 
+ + header = ["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"] + store_iter_info = Set(get(kwargs, :store_iter_info, Set())) + info_vals = Vector{Any}(undef, length(header)) + fill!(info_vals, nothing) + store_iter = false + store_J_T = false + store_g_a_int = false + store_J = false + store_ΔJ_T = false + store_ΔJ = false + store_secs = false + for item in store_iter_info + if item == "iter." + store_iter = true + elseif item == "J_T" + store_J_T = true + elseif item == "∫gₐ(t)dt" + store_g_a_int = true + elseif item == "J" + store_J = true + elseif item == "ΔJ_T" + store_ΔJ_T = true + elseif item == "ΔJ" + store_ΔJ = true + elseif item == "secs" + store_secs = true + else + msg = "Item $(repr(item)) in `store_iter_info` is not one of $(repr(header)))" + throw(ArgumentError(msg)) + end + end + -This functions serves as the default `info_hook` for an optimization with -Krotov's method. -""" -function print_table(wrk, iteration, args...) - J_T = wrk.result.J_T - g_a_int = sum(wrk.g_a_int) - J = J_T + g_a_int - ΔJ_T = J_T - wrk.result.J_T_prev - ΔJ = ΔJ_T + g_a_int - secs = wrk.result.secs - - iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))" - widths = [max(length("$iter_stop"), 6), 11, 11, 11, 11, 11, 8] - - if iteration == 0 - header = ["iter.", "J_T", "∫gₐ(t)dt", "J", "ΔJ_T", "ΔJ", "secs"] - for (header, w) in zip(header, widths) - print(lpad(header, w)) + function print_table(wrk, iteration, args...) 
+ + J_T = wrk.result.J_T + g_a_int = sum(wrk.g_a_int) + J = J_T + g_a_int + ΔJ_T = J_T - wrk.result.J_T_prev + ΔJ = ΔJ_T + g_a_int + secs = wrk.result.secs + + store_iter && (info_vals[1] = iteration) + store_J_T && (info_vals[2] = J_T) + store_g_a_int && (info_vals[3] = g_a_int) + store_J && (info_vals[4] = J) + store_ΔJ_T && (info_vals[5] = ΔJ_T) + store_ΔJ && (info_vals[6] = ΔJ) + store_secs && (info_vals[7] = secs) + + iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))" + widths = [max(length("$iter_stop"), 6), 11, 11, 11, 11, 11, 8] + + if iteration == 0 + for (header, w) in zip(header, widths) + print(lpad(header, w)) + end + print("\n") + end + + strs = ( + "$iteration", + @sprintf("%.2e", J_T), + @sprintf("%.2e", g_a_int), + @sprintf("%.2e", J), + (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a", + (iteration > 0) ? @sprintf("%.2e", ΔJ) : "n/a", + @sprintf("%.1f", secs), + ) + for (str, w) in zip(strs, widths) + print(lpad(str, w)) end print("\n") - end + flush(stdout) + + return Tuple((value for value in info_vals if (value !== nothing))) - strs = ( - "$iteration", - @sprintf("%.2e", J_T), - @sprintf("%.2e", g_a_int), - @sprintf("%.2e", J), - (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a", - (iteration > 0) ? @sprintf("%.2e", ΔJ) : "n/a", - @sprintf("%.1f", secs), - ) - for (str, w) in zip(strs, widths) - print(lpad(str, w)) end - print("\n") - flush(stdout) + + return print_table + end diff --git a/src/result.jl b/src/result.jl index 221dbdf..ab92844 100644 --- a/src/result.jl +++ b/src/result.jl @@ -17,7 +17,7 @@ The attributes of a `KrotovResult` object include discretized to the points of `tlist`) * optimized_controls: A vector of the optimized control fileds in the current iterations -* records: A vector of tuples with values returned by an `info_hook` routine +* records: A vector of tuples with values returned by a `callback` routine passed to [`optimize`](@ref) * converged: A boolean flag on whether the optimization is converged. 
This may be set to `true` by a `check_convergence` function. @@ -41,7 +41,7 @@ mutable struct KrotovResult{STST} states::Vector{STST} # the forward-propagated states after each iteration start_local_time::DateTime end_local_time::DateTime - records::Vector{Tuple} # storage for info_hook to write data into at each iteration + records::Vector{Tuple} # storage for callback to write data into at each iteration converged::Bool message::String end diff --git a/src/workspace.jl b/src/workspace.jl index e0fabee..fff90a3 100644 --- a/src/workspace.jl +++ b/src/workspace.jl @@ -3,61 +3,56 @@ using QuantumControlBase: get_control_derivs using QuantumControlBase.QuantumPropagators.Controls: get_controls, discretize_on_midpoints using QuantumControlBase: init_prop_trajectory using QuantumControlBase.QuantumPropagators.Storage: init_storage -using ConcreteStructs +using ConcreteStructs: @concrete + +"""Krotov workspace. + +The workspace is for internal use. However, it is also accessible in a +`callback` function. The callback may use or modify some of the following +attributes: + +* `trajectories`: a copy of the trajectories defining the control problem +* `adjoint_trajectories`: The `trajectories` with the adjoint generator +* `kwargs`: The keyword arguments from the [`ControlProblem`](@ref) or the + call to [`optimize`](@ref). +* `controls`: A tuple of the original controls (probably functions) +* `g_a_int`: The current value of ``∫gₐ(t)dt`` for each control +* `update_shapes`: The update shapes ``S(t)`` for each pulse, discretized on + the intervals of the time grid. +* `lambda_vals`: The current value of λₐ for each control +* `result`: The current result object +* `fw_storage`: The storage of states for the forward propagation +* `fw_propagators`: The propagators used for the forward propagation +* `bw_propagators`: The propagators used for the backward propagation +* `use_threads`: Flag indicating whether the propagations are performed in + parallel.
+""" +mutable struct KrotovWrk -# Krotov workspace (for internal use) -@concrete terse struct KrotovWrk - - # a copy of the trajectories trajectories - - # the adjoint trajectories, containing the adjoint generators for the - # backward propagation adjoint_trajectories - - # The kwargs from the control problem kwargs - - # Tuple of the original controls (probably functions) controls - # storage for controls discretized on intervals of tlist pulses0::Vector{Vector{Float64}} - # second pulse storage: pulses0 and pulses1 alternate in storing the guess # pulses and optimized pulses in each iteration pulses1::Vector{Vector{Float64}} - - # values of ∫gₐ(t)dt for each pulse g_a_int::Vector{Float64} - - # update shapes S(t) for each pulse, discretized on intervals update_shapes::Vector{Vector{Float64}} - lambda_vals::Vector{Float64} - # map of controls to options - pulse_options - - # Result object - result ################################# # scratch objects, per trajectory: control_derivs - fw_storage # forward storage array (per trajectory) - fw_storage2 # forward storage array (per trajectory) - bw_storage # backward storage array (per trajectory) - fw_propagators - bw_propagators - use_threads::Bool end @@ -85,7 +80,7 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false) @warn("`update_shape` is ignored due to given `pulse_options`") end if :lambda_a in keys(kwargs) - @warn("`lambda_a=$lambda_a` is ignored due to given `pulse_options`") + @warn("`lambda_a=$(kwargs[:lambda_a])` is ignored due to given `pulse_options`") end else if (:update_shape ∉ keys(kwargs)) && (:lambda_a ∉ keys(kwargs)) @@ -165,7 +160,7 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false) g_a_int, update_shapes, lambda_vals, - pulse_options, + # pulse_options, # XXX result, control_derivs, fw_storage, diff --git a/test/Project.toml b/test/Project.toml index 7f4511c..f618bb6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ 
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krotov = "b05dcdc7-62f6-4360-bf2c-0898bba419de" LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433" diff --git a/test/runtests.jl b/test/runtests.jl index 1d4c69f..3d41507 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,5 +17,10 @@ unicodeplots() include("test_empty_optimization.jl") end + println("\n* Iterations (test_iterations.jl)") + @time @safetestset "Iterations" begin + include("test_iterations.jl") + end + end nothing diff --git a/test/test_iterations.jl b/test/test_iterations.jl new file mode 100644 index 0000000..3bef092 --- /dev/null +++ b/test/test_iterations.jl @@ -0,0 +1,145 @@ +using Test +using QuantumControl: optimize +using StableRNGs +using LinearAlgebra: norm +using LinearAlgebra.BLAS: scal! +using Krotov +using QuantumControlTestUtils.DummyOptimization: dummy_control_problem +using QuantumControl.Functionals: J_T_ss +using IOCapture + +PASSTHROUGH = false + +@testset "iter_start_stop" begin + # Test that setting iter_start and iter_stop in fact restricts the + # optimization to those numbers + rng = StableRNG(1244568944) + problem = dummy_control_problem(; + iter_start=10, + N=2, + density=1.0, + complex_operators=false, + rng, + J_T=J_T_ss, + store_iter_info=["iter.", "J_T"] + ) + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize(problem; method=Krotov, iter_stop=12) + end + res = captured.value + @test res.converged + @test res.iter_start == 10 + @test res.iter_stop == 12 + iters = [values[1] for values in res.records] + @test iters == [0, 11, 12] +end + + +@testset "callback" begin + + rng = StableRNG(1244568944) + + function callback1(_, iter, args...) 
+ println("This is callback 1 for iter $iter") + end + + function callback2(_, iter, args...) + println("This is callback 2 for iter $iter") + return ("cb2", iter) + end + + function reduce_pulse(wrk, iter, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) + r0 = norm(ϵ⁽ⁱ⁾[1]) + r1 = norm(ϵ⁽ⁱ⁺¹⁾[1]) + scal!(0.8, ϵ⁽ⁱ⁺¹⁾[1]) + r2 = norm(ϵ⁽ⁱ⁺¹⁾[1]) + return (r0, r1, r2) + end + + problem = dummy_control_problem(; + N=2, + density=1.0, + complex_operators=false, + rng, + J_T=J_T_ss, + callback=callback1, + ) + + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize(problem; method=Krotov, iter_stop=1) + end + @test contains( + captured.output, + "This is callback 1 for iter 0\n iter. J_T ∫gₐ(t)dt J ΔJ_T ΔJ secs" + ) + @test contains(captured.output, "This is callback 1 for iter 1\n 1") + + # passing `callback` to `optimize` overwrites `callback` in `problem` + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize(problem; method=Krotov, iter_stop=1, callback=callback2) + end + @test !contains(captured.output, "This is callback 1 for iter 0") + @test !contains(captured.output, "This is callback 1 for iter 1") + @test contains(captured.output, "This is callback 2 for iter 0") + @test contains(captured.output, "This is callback 2 for iter 1") + + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize( + problem; + method=Krotov, + iter_stop=1, + callback=(callback1, callback2), + print_iters=false + ) + end + @test captured.value.converged + @test contains( + captured.output, + """ + This is callback 1 for iter 0 + This is callback 2 for iter 0 + This is callback 1 for iter 1 + This is callback 2 for iter 1 + """ + ) + @test captured.value.records == [("cb2", 0), ("cb2", 1)] + + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize( + problem; + method=Krotov, + iter_stop=1, + callback=(callback1, callback2), + store_iter_info=["J_T"] + ) + end + @test captured.value.converged + @test length(captured.value.records) == 2 + @test 
length(captured.value.records[1]) == 3 + @test captured.value.records[1][1] == "cb2" + @test captured.value.records[1][2] == 0 + @test captured.value.records[1][3] isa Float64 + + # we should also be able to modify the updated pulses in the callback and + # have that take effect. + captured = IOCapture.capture(passthrough=PASSTHROUGH) do + optimize( + problem; + method=Krotov, + iter_stop=3, + callback=reduce_pulse, + store_iter_info=["iter.", "J_T"] + ) + end + @test captured.value.converged + for i = 2:length(captured.value.records) + record = captured.value.records[i] + (nrm_guess, nrm_upd, nrm_upd_scaled, iter, J_T) = record + nrm_upd_scaled_prev = captured.value.records[i-1][3] + @test nrm_upd_scaled ≈ 0.8 * nrm_upd + if i >= 3 + @test nrm_guess ≈ nrm_upd_scaled_prev + end + end + +end diff --git a/test/test_pulse_optimization.jl b/test/test_pulse_optimization.jl index 7161974..9855dd9 100644 --- a/test/test_pulse_optimization.jl +++ b/test/test_pulse_optimization.jl @@ -16,7 +16,7 @@ using Krotov # The problem occurs when the controls are actually pulses (on the # midpoints of the time grid), so that the optimization does not have to # call `discretize_on_midpoints` internally - problem = dummy_control_problem(; pulses_as_controls=true) + problem = dummy_control_problem(; pulses_as_controls=true, rng) nt = length(problem.tlist) guess_pulse = QuantumControl.Controls.get_controls(problem.trajectories)[1] @test length(guess_pulse) == nt - 1