diff --git a/ext/GRAPELBFGSBExt.jl b/ext/GRAPELBFGSBExt.jl
index 5b6e3db..e4391f3 100644
--- a/ext/GRAPELBFGSBExt.jl
+++ b/ext/GRAPELBFGSBExt.jl
@@ -5,7 +5,7 @@ using GRAPE: GrapeWrk, update_result!
 import GRAPE: run_optimizer, gradient, step_width, search_direction
 
 
-function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, info_hook, check_convergence!)
+function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, callback, check_convergence!)
 
     m = get(wrk.kwargs, :lbfgsb_m, 10)
     factr = get(wrk.kwargs, :lbfgsb_factr, 1e7)
@@ -87,19 +87,21 @@ function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, info_hook, check_co
             # x is the guess for the 0 iteration
             copyto!(wrk.gradient, obj.g)
             update_result!(wrk, 0)
-            #update_hook!(...) # TODO
-            info_tuple = info_hook(wrk, 0)
+            info_tuple = callback(wrk, 0)
             wrk.fg_count .= 0
-            (info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+            if !(isnothing(info_tuple) || isempty(info_tuple))
+                push!(wrk.result.records, info_tuple)
+            end
         end
     elseif obj.task[1:5] == b"NEW_X"
         # x is the optimized pulses for the current iteration
         iter = wrk.result.iter_start + obj.isave[30]
         update_result!(wrk, iter)
-        #update_hook!(...) # TODO
-        info_tuple = info_hook(wrk, wrk.result.iter)
+        info_tuple = callback(wrk, wrk.result.iter)
         wrk.fg_count .= 0
-        (info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+        if !(isnothing(info_tuple) || isempty(info_tuple))
+            push!(wrk.result.records, info_tuple)
+        end
         check_convergence!(wrk.result)
         if wrk.result.converged
             fill!(obj.task, Cuchar(' '))
diff --git a/ext/GRAPEOptimExt.jl b/ext/GRAPEOptimExt.jl
index d7f164a..6c5170d 100644
--- a/ext/GRAPEOptimExt.jl
+++ b/ext/GRAPEOptimExt.jl
@@ -9,12 +9,12 @@ function run_optimizer(
     optimizer::Optim.AbstractOptimizer,
     wrk,
     fg!,
-    info_hook,
+    callback,
     check_convergence!
 )
 
     tol_options = Optim.Options(
-        # just so we can instantiate `optimizer_state` before `callback`
+        # just so we can instantiate `optimizer_state` before `optim_callback`
         x_tol=get(wrk.kwargs, :x_tol, 0.0),
         f_tol=get(wrk.kwargs, :f_tol, 0.0),
         g_tol=get(wrk.kwargs, :g_tol, 1e-8),
@@ -38,7 +38,7 @@ function run_optimizer(
     @assert isnan(wrk.optimizer_state.f_x_previous)
 
     # update the result object and check convergence
-    function callback(optimization_state::Optim.OptimizationState)
+    function optim_callback(optimization_state::Optim.OptimizationState)
         iter = wrk.result.iter_start + optimization_state.iteration
         #@assert optimization_state.value == objective.F
         #if optimization_state.iteration > 0
@@ -48,11 +48,10 @@ function run_optimizer(
         #        ) < 1e-14
         #end
         update_result!(wrk, iter)
-        #update_hook!(...) # TODO
-        info_tuple = info_hook(wrk, wrk.result.iter)
+        info_tuple = callback(wrk, wrk.result.iter)
         if hasproperty(objective, :DF)
             # DF is the *current* gradient, i.e., the gradient of the updated
-            # pulsevals, which (after the call to `info_hook`) is the gradient
-            # for the the guess of the next iteration.
+            # pulsevals, which (after the call to `callback`) is the gradient
+            # for the guess of the next iteration.
             wrk.gradient .= objective.DF
         elseif (optimization_state.iteration == 1)
@@ -60,13 +59,15 @@ function run_optimizer(
         end
         copyto!(wrk.pulsevals_guess, wrk.pulsevals)
         wrk.fg_count .= 0
-        (info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+        if !(isnothing(info_tuple) || isempty(info_tuple))
+            push!(wrk.result.records, info_tuple)
+        end
         check_convergence!(wrk.result)
         return wrk.result.converged
     end
 
     options = Optim.Options(
-        callback=callback,
+        callback=optim_callback,
         iterations=wrk.result.iter_stop - wrk.result.iter_start, # TODO
         x_tol=get(wrk.kwargs, :x_tol, 0.0),
         f_tol=get(wrk.kwargs, :f_tol, 0.0),
diff --git a/src/optimize.jl b/src/optimize.jl
index 5c7b7d1..9a9aaa4 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -8,7 +8,7 @@ using QuantumControlBase: @threadsif
 using LinearAlgebra
 using Printf
 
-import QuantumControlBase: optimize
+import QuantumControlBase: optimize, make_print_iters
 
 @doc raw"""
 ```julia
@@ -94,19 +94,25 @@ with explicit keyword arguments to `optimize`.
   - `:lower_bounds`: A vector of lower bound values. Values of `-Inf` indicate
     an unconstrained lower bound for that time interval,
 
-* `update_hook`: Not implemented
-* `info_hook`: A function (or tuple of functions) that receives the same
-  arguments as `update_hook`, in order to write information about the current
-  iteration to the screen or to a file. The default `info_hook` prints a table
-  with convergence information to the screen. Runs after `update_hook`. The
-  `info_hook` function may return a tuple, which is stored in the list of
-  `records` inside the [`GrapeResult`](@ref) object.
+* `print_iters=true`: Whether to print information after each iteration.
+* `store_iter_info=Set()`: Which fields from the `print_iters` output to store
+  in `result.records`. A subset of
+  `Set(["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"])`.
+* `callback`: A function (or tuple of functions) that receives the
+  [GRAPE workspace](@ref GrapeWrk) and the iteration number. The function
+  may return a tuple of values which are stored in the
+  [`GrapeResult`](@ref) object `result.records`. The function can also mutate
+  the workspace, in particular the updated `pulsevals`. This may be used,
+  e.g., to apply a spectral filter to the updated pulses or to perform
+  similar manipulations. Note that `print_iters=true` (default) adds an
+  automatic callback to print information after each iteration. With
+  `store_iter_info`, that callback automatically stores a subset of the
+  printed information.
 * `check_convergence`: A function to check whether convergence has been
   reached. Receives a [`GrapeResult`](@ref) object `result`, and should set
   `result.converged` to `true` and `result.message` to an appropriate string in
   case of convergence. Multiple convergence checks can be performed by chaining
-  functions with `∘`. The convergence check is performed after any calls to
-  `update_hook` and `info_hook`.
+  functions with `∘`. The convergence check is performed after any `callback`.
 * `x_tol`: Parameter for Optim.jl
 * `f_tol`: Parameter for Optim.jl
 * `g_tol`: Parameter for Optim.jl
@@ -159,11 +165,12 @@ optimize(problem, method::Val{:grape}) = optimize_grape(problem)
 See [`optimize(problem; method=GRAPE, kwargs...)`](@ref optimize(::Any, ::Val{:GRAPE})).
 """
 function optimize_grape(problem)
-    update_hook! = get(problem.kwargs, :update_hook, (args...) -> nothing)
-    # TODO: implement update_hook
-    # TODO: streamline the interface for info_hook
     # TODO: check if x_tol, f_tol, g_tol are used necessary / used correctly
-    info_hook = get(problem.kwargs, :info_hook, print_table)
+    callback = get(problem.kwargs, :callback, (args...) -> nothing)
+    if haskey(problem.kwargs, :update_hook) || haskey(problem.kwargs, :info_hook)
+        msg = "The `update_hook` and `info_hook` arguments have been superseded by the `callback` argument"
+        throw(ArgumentError(msg))
+    end
     check_convergence! = get(problem.kwargs, :check_convergence, res -> res)
     verbose = get(problem.kwargs, :verbose, false)
     gradient_method = get(problem.kwargs, :gradient_method, :gradgen)
@@ -380,9 +387,9 @@ function optimize_grape(problem)
     end
     try
         if gradient_method == :gradgen
-            run_optimizer(optimizer, wrk, fg_gradgen!, info_hook, check_convergence!)
+            run_optimizer(optimizer, wrk, fg_gradgen!, callback, check_convergence!)
         elseif gradient_method == :taylor
-            run_optimizer(optimizer, wrk, fg_taylor!, info_hook, check_convergence!)
+            run_optimizer(optimizer, wrk, fg_taylor!, callback, check_convergence!)
         else
             error("Invalid gradient_method=$(repr(gradient_method)) ∉ (:gradgen, :taylor)")
         end
@@ -441,64 +448,103 @@ function finalize_result!(wrk::GrapeWrk)
 end
 
 
+make_print_iters(::Val{:GRAPE}; kwargs...) = make_grape_print_iters(; kwargs...)
+make_print_iters(::Val{:grape}; kwargs...) = make_grape_print_iters(; kwargs...)
+
-"""Print optimization progress as a table.
-
-This functions serves as the default `info_hook` for an optimization with
-GRAPE.
-"""
+"""Construct the default progress printer.
+
+The returned function prints optimization progress as a table and serves as
+the default `callback` for an optimization with GRAPE.
+"""
-function print_table(wrk, iteration, args...)
-    # TODO: make_print_table that precomputes headers and such, and maybe
-    # allows for more options.
-    # TODO: should we report ΔJ instead of ΔJ_T?
-
-    J_T = wrk.result.J_T
-    ΔJ_T = J_T - wrk.result.J_T_prev
-    secs = wrk.result.secs
+function make_grape_print_iters(; kwargs...)
     headers = ["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"]
-    if wrk.J_parts[2] ≠ 0.0
-        headers = ["iter.", "J_T", "|∇J_T|", "|∇J_a|", "ΔJ_T", "FG(F)", "secs"]
+    store_iter_info = Set(get(kwargs, :store_iter_info, Set()))
+    info_vals = Vector{Any}(undef, length(headers))
+    fill!(info_vals, nothing)
+    store_iter = false
+    store_J_T = false
+    store_grad_norm = false
+    store_ΔJ_T = false
+    store_counts = false
+    store_secs = false
+    for item in store_iter_info
+        if item == "iter."
+            store_iter = true
+        elseif item == "J_T"
+            store_J_T = true
+        elseif item == "|∇J_T|"
+            store_grad_norm = true
+        elseif item == "ΔJ_T"
+            store_ΔJ_T = true
+        elseif item == "FG(F)"
+            store_counts = true
+        elseif item == "secs"
+            store_secs = true
+        else
+            msg = "Item $(repr(item)) in `store_iter_info` is not one of $(repr(headers))"
+            throw(ArgumentError(msg))
+        end
+    end
     end
-    iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
-    width = Dict(
-        "iter." => max(length("$iter_stop"), 6),
-        "J_T" => 11,
-        "|∇J_T|" => 11,
-        "|∇J_a|" => 11,
-        "|∇J|" => 11,
-        "ΔJ" => 11,
-        "ΔJ_T" => 11,
-        "FG(F)" => 8,
-        "secs" => 8,
-    )
-
-    if iteration == 0
-        for header in headers
+    function print_table(wrk, iteration, args...)
+
+        J_T = wrk.result.J_T
+        ΔJ_T = J_T - wrk.result.J_T_prev
+        secs = wrk.result.secs
+        grad_norm = norm(wrk.grad_J_T)
+        counts = Tuple(wrk.fg_count)
+
+        iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
+        width = Dict(
+            "iter." => max(length("$iter_stop"), 6),
+            "J_T" => 11,
+            "|∇J_T|" => 11,
+            "|∇J_a|" => 11,
+            "|∇J|" => 11,
+            "ΔJ" => 11,
+            "ΔJ_T" => 11,
+            "FG(F)" => 8,
+            "secs" => 8,
+        )
+
+        store_iter && (info_vals[1] = iteration)
+        store_J_T && (info_vals[2] = J_T)
+        store_grad_norm && (info_vals[3] = grad_norm)
+        store_ΔJ_T && (info_vals[4] = ΔJ_T)
+        store_counts && (info_vals[5] = counts)
+        store_secs && (info_vals[6] = secs)
+
+        if iteration == 0
+            for header in headers
+                w = width[header]
+                print(lpad(header, w))
+            end
+            print("\n")
+        end
+
+        strs = [
+            "$iteration",
+            @sprintf("%.2e", J_T),
+            @sprintf("%.2e", grad_norm),
+            (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
+            @sprintf("%d(%d)", counts[1], counts[2]),
+            @sprintf("%.1f", secs),
+        ]
+        for (str, header) in zip(strs, headers)
             w = width[header]
-            print(lpad(header, w))
+            print(lpad(str, w))
         end
         print("\n")
-    end
+        flush(stdout)
+
+        return Tuple((value for value in info_vals if (value !== nothing)))
+    end
 
-    strs = [
-        "$iteration",
-        @sprintf("%.2e", J_T),
-        @sprintf("%.2e", norm(wrk.grad_J_T)),
-        (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
-        @sprintf("%d(%d)", wrk.fg_count[1], wrk.fg_count[2]),
-        @sprintf("%.1f", secs),
-    ]
-    if wrk.J_parts[2] ≠ 0.0
-        insert!(strs, 4, @sprintf("%.2e", norm(wrk.grad_J_a)))
-    end
-    for (str, header) in zip(strs, headers)
-        w = width[header]
-        print(lpad(str, w))
-    end
-    print("\n")
-    flush(stdout)
+
+    return print_table
+
 end
diff --git a/src/result.jl b/src/result.jl
index 8648a47..12b35e7 100644
--- a/src/result.jl
+++ b/src/result.jl
@@ -16,7 +16,7 @@ The attributes of a `GrapeResult` object include
   discretized to the points of `tlist`)
 * optimized_controls: A vector of the optimized control fileds in the current
   iterations
-* records: A vector of tuples with values returned by an `info_hook` routine
+* records: A vector of tuples with values returned by a `callback` routine
   passed to [`optimize`](@ref)
 * converged: A boolean flag on whether the optimization is converged. This
   may be set to `true` by a `check_convergence` function.
@@ -40,7 +40,7 @@ mutable struct GrapeResult{STST}
     states::Vector{STST}
     start_local_time::DateTime
     end_local_time::DateTime
-    records::Vector{Tuple}  # storage for info_hook to write data into at each iteration
+    records::Vector{Tuple}  # storage for callback to write data into at each iteration
     converged::Bool
     f_calls::Int64
     fg_calls::Int64
diff --git a/src/workspace.jl b/src/workspace.jl
index 49fbed6..576100a 100644
--- a/src/workspace.jl
+++ b/src/workspace.jl
@@ -5,9 +5,44 @@ using QuantumControlBase: Trajectory, get_control_derivs, init_prop_trajectory
 using QuantumGradientGenerators: GradVector, GradGenerator
 
 import LBFGSB
 
-"""Grape Workspace.
-
-# Methods
+"""GRAPE Workspace.
+
+The workspace is for internal use. However, it is also accessible in a
+`callback` function. The callback may use or modify some of the following
+attributes:
+
+* `trajectories`: a copy of the trajectories defining the control problem
+* `adjoint_trajectories`: The `trajectories` with the adjoint generator
+* `kwargs`: The keyword arguments from the [`ControlProblem`](@ref) or the
+  call to [`optimize`](@ref).
+* `controls`: A tuple of the original controls (probably functions)
+* `pulsevals_guess`: The combined vector of pulse values that are the guess in
+  the current iteration. Initially, the vector is the concatenation of the
+  `controls` discretized to the midpoints of the time grid.
+* `pulsevals`: The combined vector of updated pulse values in the current
+  iteration.
+* `gradient`: The total gradient for the guess in the current iteration
+* `grad_J_T`: The current gradient for the final-time part of the functional.
+* `grad_J_a`: The current gradient for the running cost part of the functional.
+* `J_parts`: The two-component vector ``[J_T, J_a]``
+* `result`: The current result object
+* `upper_bounds`: Upper bound for every `pulsevals`; `+Inf` indicates no bound.
+* `lower_bounds`: Lower bound for every `pulsevals`; `-Inf` indicates no bound.
+* `fg_count`: The total number of evaluations of the functional and of the
+  gradient, as a two-element vector.
+* `optimizer`: The backend optimizer object
+* `optimizer_state`: The internal state object of the `optimizer` (`nothing` if
+  the `optimizer` has no internal state)
+* `tau_grads`: The gradients ∂τₖ/∂ϵₗ(tₙ)
+* `fw_storage`: The storage of states for the forward propagation
+* `fw_propagators`: The propagators used for the forward propagation
+* `bw_propagators`: The propagators used for the backward propagation
+* `use_threads`: Flag indicating whether the propagations are performed in
+  parallel.
+
+In addition, the following methods provide safer (non-mutating) access to
+information in the workspace:
 
 * [`step_width`](@ref)
 * [`search_direction`](@ref)
@@ -16,55 +51,23 @@ import LBFGSB
 """
 mutable struct GrapeWrk{O}
 
-    # a copy of the trajectories
     trajectories
-
-    # the adjoint trajectories, containing the adjoint generators for the
-    # backward propagation
     adjoint_trajectories
-
     # trajectories for bw-prop of gradients
     grad_trajectories
-
-    # The kwargs from the control problem
     kwargs
-
-    # Tuple of the original controls (probably functions)
     controls
-
     pulsevals_guess::Vector{Float64}
-
     pulsevals::Vector{Float64}
-
-    # total gradient for guess in iterations
     gradient::Vector{Float64}
-
-    # storage for current final time gradient
     grad_J_T::Vector{Float64}
-
-    # storage for current running cost gradient
     grad_J_a::Vector{Float64}
-
-    # two-component vector [J_T, J_a]
     J_parts::Vector{Float64}
-
-    # Upper bound for every `pulsevals`, +Inf indicates no bound
     upper_bounds::Vector{Float64}
-
-    # Upper bound for every `pulsevals`, -Inf indicates no bound
     lower_bounds::Vector{Float64}
-
     fg_count::Vector{Int64}
-
-    # map of controls to options
-    pulse_options
-
-    # The optimizer
     optimizer::O
-
-    # Internal optimizer state (`nothing` if `optimizer` has not state)
     optimizer_state
-
     result
 
     #################################
@@ -72,37 +75,25 @@
     # backward-propagated states
     chi_states
-
-    # gradients ∂τₖ/ϵₗ(tₙ)
     tau_grads::Vector{Matrix{ComplexF64}}
-
-    # backward storage array
     fw_storage
-
-    # for normal forward propagation
     fw_propagators
-
     # for gradient backward propagation
     # gradient_method=:gradgen only
     bw_grad_propagators
-
     # for normal backward propagation
     # gradient_method=:taylor only
     bw_propagators
-
     # evaluated Hₖ for a particular point in time
     # gradient_method=:taylor only
     taylor_genops
-
     # derivatives ∂Hₖ/∂ϵₗ(t)
     # gradient_method=:taylor only
     control_derivs
-
     # 5 temporary states for each trajectory and each control, for evaluating
     # gradients via Taylor expansions
     # gradient_method=:taylor only
     taylor_grad_states
-
     use_threads::Bool
 end
@@ -173,7 +164,6 @@ function GrapeWrk(problem::QuantumControlBase.ControlProblem; verbose=false)
             lb .= options[:lower_bounds]
         end
     end
-    dummy_vals = IdDict(control => 1.0 for (i, control) in enumerate(controls))
     fw_storage = [init_storage(traj.initial_state, tlist) for traj in trajectories]
     kwargs[:piecewise] = true  # only accept piecewise propagators
     _prefixes = ["prop_", "fw_prop_"]
@@ -264,7 +254,6 @@ function GrapeWrk(problem::QuantumControlBase.ControlProblem; verbose=false)
         upper_bounds,
         lower_bounds,
         fg_count,
-        pulse_options,
         optimizer,
         optimizer_state,
         result,
@@ -297,8 +286,8 @@ end
 
 returns the scalar `α` so that `pulse_update(wrk) = α * search_direction(wrk)`,
 see [`pulse_update`](@ref) and [`search_direction`](@ref) for the iteration
-desribed by the current [`GrapeWrk`](@ref) (for the state of `wrk` as available
-in the `info_hook` of the current iteration.
+described by the current [`GrapeWrk`](@ref) (for the state of `wrk` as available
+in the `callback` of the current iteration).
 """
 function step_width(wrk)
     u = pulse_update(wrk)
@@ -362,7 +351,7 @@ end
 Δu = pulse_update(wrk)
 ```
 
-returns a vector conntaining the different between the optimized pulse values
+returns a vector containing the difference between the optimized pulse values
 and the guess pulse values of the current iteration. This should be
 proportional to [`search_direction`](@ref) with the proportionality factor
 [`step_width`](@ref).
diff --git a/test/runtests.jl b/test/runtests.jl
index bb9e1bc..e3fdc6e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -27,5 +27,10 @@ unicodeplots()
         include("test_taylor_grad.jl")
     end
 
+    println("\n* Iterations (test_iterations.jl)")
+    @time @safetestset "Iterations" begin
+        include("test_iterations.jl")
+    end
+
 end
 nothing
diff --git a/test/test_iterations.jl b/test/test_iterations.jl
new file mode 100644
index 0000000..3347e25
--- /dev/null
+++ b/test/test_iterations.jl
@@ -0,0 +1,142 @@
+using Test
+using QuantumControl: optimize
+using StableRNGs
+using LinearAlgebra: norm
+using LinearAlgebra.BLAS: scal!
+using GRAPE
+using QuantumControlTestUtils.DummyOptimization: dummy_control_problem
+using QuantumControl.Functionals: J_T_ss
+using IOCapture
+
+PASSTHROUGH = false
+
+@testset "iter_start_stop" begin
+    # Test that setting iter_start and iter_stop in fact restricts the
+    # optimization to those numbers
+    rng = StableRNG(1244568944)
+    problem = dummy_control_problem(;
+        iter_start=10,
+        N=2,
+        density=1.0,
+        complex_operators=false,
+        rng,
+        J_T=J_T_ss,
+        store_iter_info=["iter.", "J_T"]
+    )
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(problem; method=GRAPE, iter_stop=12)
+    end
+    res = captured.value
+    @test res.converged
+    @test res.iter_start == 10
+    @test res.iter_stop == 12
+    iters = [values[1] for values in res.records]
+    @test iters == [0, 11, 12]
+end
+
+
+@testset "callback" begin
+
+    rng = StableRNG(1244568944)
+
+    function callback1(_, iter, args...)
+        println("This is callback 1 for iter $iter")
+    end
+
+    function callback2(_, iter, args...)
+        println("This is callback 2 for iter $iter")
+        return ("cb2", iter)
+    end
+
+    function reduce_pulse(wrk, iter)
+        r0 = norm(wrk.pulsevals_guess)
+        r1 = norm(wrk.pulsevals)
+        scal!(0.8, wrk.pulsevals)
+        r2 = norm(wrk.pulsevals)
+        return (r0, r1, r2)
+    end
+
+    problem = dummy_control_problem(;
+        N=2,
+        density=1.0,
+        complex_operators=false,
+        rng,
+        J_T=J_T_ss,
+        callback=callback1,
+    )
+
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(problem; method=GRAPE, iter_stop=1)
+    end
+    @test contains(captured.output, "This is callback 1 for iter 0\n iter.")
+    @test contains(captured.output, "This is callback 1 for iter 1\n 1")
+
+    # passing `callback` to `optimize` overwrites `callback` in `problem`
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(problem; method=GRAPE, iter_stop=1, callback=callback2)
+    end
+    @test captured.value.converged
+    @test !contains(captured.output, "This is callback 1 for iter 0")
+    @test !contains(captured.output, "This is callback 1 for iter 1")
+    @test contains(captured.output, "This is callback 2 for iter 0")
+    @test contains(captured.output, "This is callback 2 for iter 1")
+
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(
+            problem;
+            method=GRAPE,
+            iter_stop=1,
+            callback=(callback1, callback2),
+            print_iters=false
+        )
+    end
+    @test captured.value.converged
+    @test contains(
+        captured.output,
+        """
+        This is callback 1 for iter 0
+        This is callback 2 for iter 0
+        This is callback 1 for iter 1
+        This is callback 2 for iter 1
+        """
+    )
+    @test captured.value.records == [("cb2", 0), ("cb2", 1)]
+
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(
+            problem;
+            method=GRAPE,
+            iter_stop=1,
+            callback=(callback1, callback2),
+            store_iter_info=["J_T"]
+        )
+    end
+    @test captured.value.converged
+    @test length(captured.value.records) == 2
+    @test length(captured.value.records[1]) == 3
+    @test captured.value.records[1][1] == "cb2"
+    @test captured.value.records[1][2] == 0
+    @test captured.value.records[1][3] isa Float64
+
+    # we should also be able to modify the updated pulses in the callback and
+    # have that take effect.
+    captured = IOCapture.capture(passthrough=PASSTHROUGH) do
+        optimize(
+            problem;
+            method=GRAPE,
+            iter_stop=3,
+            callback=reduce_pulse,
+            store_iter_info=["iter.", "J_T"]
+        )
+    end
+    for i = 2:length(captured.value.records)
+        record = captured.value.records[i]
+        (nrm_guess, nrm_upd, nrm_upd_scaled, iter, J_T) = record
+        nrm_upd_scaled_prev = captured.value.records[i-1][3]
+        @test nrm_upd_scaled ≈ 0.8 * nrm_upd
+        if i >= 3
+            @test nrm_guess ≈ nrm_upd_scaled_prev
+        end
+    end
+
+end
diff --git a/test/test_tls_optimization.jl b/test/test_tls_optimization.jl
index 0bb455e..25f12f6 100644
--- a/test/test_tls_optimization.jl
+++ b/test/test_tls_optimization.jl
@@ -51,12 +51,6 @@ function ls_info_hook(wrk, iter)
 end
 
 
-function J_T_info_hook(wrk, iter, args...)
-    J_T = wrk.result.J_T
-    return (J_T,)
-end
-
-
 function print_ls_table(res)
     println("")
     @printf("%6s", "iter")
@@ -95,7 +89,7 @@ end
         check_convergence=res -> begin
             ((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
         end,
-        info_hook=(ls_info_hook, GRAPE.print_table,)
+        callback=ls_info_hook,
     )
     res = optimize(problem; method=GRAPE)
     print_ls_table(res)
@@ -125,7 +119,7 @@ end
         check_convergence=res -> begin
             ((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
         end,
-        info_hook=(ls_info_hook, GRAPE.print_table,)
+        callback=ls_info_hook,
     )
     res = optimize(problem; method=GRAPE)
     print_ls_table(res)
@@ -151,7 +145,7 @@ end
         iter_stop=5,
         prop_method=ExpProp,
         J_T=J_T_sm,
-        info_hook=(ls_info_hook, GRAPE.print_table,),
+        callback=ls_info_hook,
         lbfgsb_iprint=100,
     )
     res = optimize(problem; method=GRAPE)
@@ -186,7 +180,7 @@ end
         check_convergence=res -> begin
            ((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
         end,
-        info_hook=(ls_info_hook, GRAPE.print_table,)
+        callback=ls_info_hook,
    )
     res = optimize(problem; method=GRAPE)
     print_ls_table(res)
@@ -218,7 +212,7 @@ end
         check_convergence=res -> begin
            ((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
         end,
-        info_hook=(ls_info_hook, GRAPE.print_table,)
+        callback=ls_info_hook,
    )
     res = optimize(problem; method=GRAPE)
     print_ls_table(res)
@@ -248,12 +242,8 @@ end
         end,
     )
     res_krotov = optimize(problem; method=Krotov, lambda_a=100.0, iter_stop=2)
-    res = optimize(
-        problem;
-        method=GRAPE,
-        continue_from=res_krotov,
-        info_hook=(J_T_info_hook, GRAPE.print_table,)
-    )
+    res =
+        optimize(problem; method=GRAPE, continue_from=res_krotov, store_iter_info=["J_T"],)
     display(res)
     @test res.J_T < 1e-3
     @test abs(res.records[1][1] - res_krotov.J_T) < 1e-14
@@ -285,7 +275,7 @@ end
         method=Krotov,
         continue_from=res_grape,
         lambda_a=1.0,
-        info_hook=(J_T_info_hook, Krotov.print_table,)
+        store_iter_info=["J_T"],
     )
     display(res)
     @test res.J_T < 1e-3
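Usage sketch for the new `callback`/`store_iter_info` interface introduced by this patch (not part of the patch itself): `dummy_control_problem`, `J_T_ss`, the `StableRNG` seed, and the `scal!`-based pulse rescaling are taken from `test/test_iterations.jl` above; `damp_pulses` is a hypothetical user callback in the spirit of `reduce_pulse`.

```julia
using QuantumControl: optimize
using QuantumControl.Functionals: J_T_ss
using QuantumControlTestUtils.DummyOptimization: dummy_control_problem
using LinearAlgebra: norm
using LinearAlgebra.BLAS: scal!
using StableRNGs
using GRAPE

# A callback receives the GrapeWrk workspace and the iteration number. It may
# mutate the updated pulses (`wrk.pulsevals`) in place and may return a tuple
# of values that is stored in `result.records` for that iteration.
function damp_pulses(wrk, iteration)
    scal!(0.9, wrk.pulsevals)      # damp the updated pulses in place
    return (norm(wrk.pulsevals),)  # value to store in result.records
end

rng = StableRNG(1244568944)
problem = dummy_control_problem(;
    N=2,
    density=1.0,
    complex_operators=false,
    rng,
    J_T=J_T_ss,
)

result = optimize(
    problem;
    method=GRAPE,
    iter_stop=5,
    callback=damp_pulses,
    store_iter_info=["iter.", "J_T"],  # fields of the printed table to keep
)

# Each element of result.records is a tuple (pulse_norm, iteration, J_T):
# the values returned by `damp_pulses`, followed by the stored table fields,
# matching the record layout checked in the "callback" testset above.
```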