Combine info-hook/update-hook into callback #71

Merged · 1 commit · Jul 13, 2024
16 changes: 9 additions & 7 deletions ext/GRAPELBFGSBExt.jl
@@ -5,7 +5,7 @@ using GRAPE: GrapeWrk, update_result!
import GRAPE: run_optimizer, gradient, step_width, search_direction


-function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, info_hook, check_convergence!)
+function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, callback, check_convergence!)

m = get(wrk.kwargs, :lbfgsb_m, 10)
factr = get(wrk.kwargs, :lbfgsb_factr, 1e7)
@@ -87,19 +87,21 @@ function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, info_hook, check_convergence!)
# x is the guess for the 0 iteration
copyto!(wrk.gradient, obj.g)
update_result!(wrk, 0)
-#update_hook!(...) # TODO
-info_tuple = info_hook(wrk, 0)
+info_tuple = callback(wrk, 0)
wrk.fg_count .= 0
-(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+if !(isnothing(info_tuple) || isempty(info_tuple))
+    push!(wrk.result.records, info_tuple)
+end
end
elseif obj.task[1:5] == b"NEW_X"
# x is the optimized pulses for the current iteration
iter = wrk.result.iter_start + obj.isave[30]
update_result!(wrk, iter)
-#update_hook!(...) # TODO
-info_tuple = info_hook(wrk, wrk.result.iter)
+info_tuple = callback(wrk, wrk.result.iter)
wrk.fg_count .= 0
-(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+if !(isnothing(info_tuple) || isempty(info_tuple))
+    push!(wrk.result.records, info_tuple)
+end
check_convergence!(wrk.result)
if wrk.result.converged
fill!(obj.task, Cuchar(' '))
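To illustrate the contract this extension now enforces: a user `callback` receives the workspace and the iteration number, and only a non-empty tuple return value is pushed to `result.records`. A minimal sketch (the name `my_callback` is hypothetical, and `problem` stands for a user-defined control problem):

```julia
# Hypothetical callback following the `callback(wrk, iteration)` signature
# used in `run_optimizer` above.
function my_callback(wrk, iteration, args...)
    if iteration % 10 == 0
        # A non-empty tuple is stored in `wrk.result.records`
        return (iteration, wrk.result.J_T)
    end
    # `nothing` or an empty tuple is skipped by the new
    # `isnothing`/`isempty` guard
    return ()
end

# result = optimize(problem; method=GRAPE, callback=my_callback)
```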
17 changes: 9 additions & 8 deletions ext/GRAPEOptimExt.jl
@@ -9,12 +9,12 @@ function run_optimizer(
optimizer::Optim.AbstractOptimizer,
wrk,
fg!,
-info_hook,
+callback,
check_convergence!
)

tol_options = Optim.Options(
-# just so we can instantiate `optimizer_state` before `callback`
+# just so we can instantiate `optimizer_state` before `optim_callback`
x_tol=get(wrk.kwargs, :x_tol, 0.0),
f_tol=get(wrk.kwargs, :f_tol, 0.0),
g_tol=get(wrk.kwargs, :g_tol, 1e-8),
@@ -38,7 +38,7 @@ function run_optimizer(
@assert isnan(wrk.optimizer_state.f_x_previous)

# update the result object and check convergence
-function callback(optimization_state::Optim.OptimizationState)
+function optim_callback(optimization_state::Optim.OptimizationState)
iter = wrk.result.iter_start + optimization_state.iteration
#@assert optimization_state.value == objective.F
#if optimization_state.iteration > 0
@@ -48,25 +48,26 @@
# ) < 1e-14
#end
update_result!(wrk, iter)
-#update_hook!(...) # TODO
-info_tuple = info_hook(wrk, wrk.result.iter)
+info_tuple = callback(wrk, wrk.result.iter)
if hasproperty(objective, :DF)
# DF is the *current* gradient, i.e., the gradient of the updated
-# pulsevals, which (after the call to `info_hook`) is the gradient
+# pulsevals, which (after the call to `callback`) is the gradient
# for the guess of the next iteration.
wrk.gradient .= objective.DF
elseif (optimization_state.iteration == 1)
@error "Cannot determine guess gradient"
end
copyto!(wrk.pulsevals_guess, wrk.pulsevals)
wrk.fg_count .= 0
-(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
+if !(isnothing(info_tuple) || isempty(info_tuple))
+    push!(wrk.result.records, info_tuple)
+end
check_convergence!(wrk.result)
return wrk.result.converged
end

options = Optim.Options(
-callback=callback,
+callback=optim_callback,
iterations=wrk.result.iter_stop - wrk.result.iter_start, # TODO
x_tol=get(wrk.kwargs, :x_tol, 0.0),
f_tol=get(wrk.kwargs, :f_tol, 0.0),
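For context, a standalone sketch of the Optim.jl callback mechanism that `optim_callback` hooks into (not GRAPE-specific; assumes a recent Optim.jl): the function passed via `Optim.Options(callback=...)` is invoked once per iteration with an `OptimizationState`, and returning `true` halts the optimization. This is how `optim_callback` propagates `wrk.result.converged` back to Optim.jl.

```julia
using Optim

# Toy objective: the Rosenbrock function
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

# Returning `true` halts the optimization early
stop_after_5(state) = state.iteration >= 5

result = Optim.optimize(
    rosenbrock,
    zeros(2),
    LBFGS(),
    Optim.Options(callback=stop_after_5),
)
```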
164 changes: 105 additions & 59 deletions src/optimize.jl
@@ -8,7 +8,7 @@
using LinearAlgebra
using Printf

-import QuantumControlBase: optimize
+import QuantumControlBase: optimize, make_print_iters

@doc raw"""
```julia
@@ -94,19 +94,25 @@
- `:lower_bounds`: A vector of lower bound values. Values of `-Inf` indicate
an unconstrained lower bound for that time interval,

-* `update_hook`: Not implemented
-* `info_hook`: A function (or tuple of functions) that receives the same
-  arguments as `update_hook`, in order to write information about the current
-  iteration to the screen or to a file. The default `info_hook` prints a table
-  with convergence information to the screen. Runs after `update_hook`. The
-  `info_hook` function may return a tuple, which is stored in the list of
-  `records` inside the [`GrapeResult`](@ref) object.
+* `print_iters=true`: Whether to print information after each iteration.
+* `store_iter_info=Set()`: Which fields from `print_iters` to store in
+  `result.records`. A subset of
+  `Set(["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"])`.
+* `callback`: A function (or tuple of functions) that receives the
+  [GRAPE workspace](@ref GrapeWrk) and the iteration number. The function
+  may return a tuple of values which are stored in the
+  [`GrapeResult`](@ref) object `result.records`. The function can also mutate
+  the workspace, in particular the updated `pulsevals`. This may be used,
+  e.g., to apply a spectral filter to the updated pulses or to perform
+  similar manipulations. Note that `print_iters=true` (default) adds an
+  automatic callback to print information after each iteration. With
+  `store_iter_info`, that callback automatically stores a subset of the
+  printed information.
* `check_convergence`: A function to check whether convergence has been
reached. Receives a [`GrapeResult`](@ref) object `result`, and should set
`result.converged` to `true` and `result.message` to an appropriate string in
case of convergence. Multiple convergence checks can be performed by chaining
-  functions with `∘`. The convergence check is performed after any calls to
-  `update_hook` and `info_hook`.
+  functions with `∘`. The convergence check is performed after any `callback`.
* `x_tol`: Parameter for Optim.jl
* `f_tol`: Parameter for Optim.jl
* `g_tol`: Parameter for Optim.jl
@@ -159,11 +165,12 @@
See [`optimize(problem; method=GRAPE, kwargs...)`](@ref optimize(::Any, ::Val{:GRAPE})).
"""
function optimize_grape(problem)
-update_hook! = get(problem.kwargs, :update_hook, (args...) -> nothing)
-# TODO: implement update_hook
-# TODO: streamline the interface for info_hook
# TODO: check if x_tol, f_tol, g_tol are used necessary / used correctly
-info_hook = get(problem.kwargs, :info_hook, print_table)
+callback = get(problem.kwargs, :callback, (args...) -> nothing)
+if haskey(problem.kwargs, :update_hook) || haskey(problem.kwargs, :info_hook)
+    msg = "The `update_hook` and `info_hook` arguments have been superseded by the `callback` argument"
+    throw(ArgumentError(msg))
+end
check_convergence! = get(problem.kwargs, :check_convergence, res -> res)
verbose = get(problem.kwargs, :verbose, false)
gradient_method = get(problem.kwargs, :gradient_method, :gradgen)
@@ -380,9 +387,9 @@
end
try
if gradient_method == :gradgen
-run_optimizer(optimizer, wrk, fg_gradgen!, info_hook, check_convergence!)
+run_optimizer(optimizer, wrk, fg_gradgen!, callback, check_convergence!)
elseif gradient_method == :taylor
-run_optimizer(optimizer, wrk, fg_taylor!, info_hook, check_convergence!)
+run_optimizer(optimizer, wrk, fg_taylor!, callback, check_convergence!)
else
error("Invalid gradient_method=$(repr(gradient_method)) ∉ (:gradgen, :taylor)")
end
@@ -441,64 +448,103 @@
end


+make_print_iters(::Val{:GRAPE}; kwargs...) = make_grape_print_iters(; kwargs...)
+make_print_iters(::Val{:grape}; kwargs...) = make_grape_print_iters(; kwargs...)



"""Print optimization progress as a table.

-This functions serves as the default `info_hook` for an optimization with
-GRAPE.
-"""
-function print_table(wrk, iteration, args...)
-    # TODO: make_print_table that precomputes headers and such, and maybe
-    # allows for more options.
-    # TODO: should we report ΔJ instead of ΔJ_T?

-    J_T = wrk.result.J_T
-    ΔJ_T = J_T - wrk.result.J_T_prev
-    secs = wrk.result.secs
+function make_grape_print_iters(; kwargs...)

headers = ["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"]
-    if wrk.J_parts[2] ≠ 0.0
-        headers = ["iter.", "J_T", "|∇J_T|", "|∇J_a|", "ΔJ_T", "FG(F)", "secs"]
-    end
+    store_iter_info = Set(get(kwargs, :store_iter_info, Set()))
+    info_vals = Vector{Any}(undef, length(headers))
+    fill!(info_vals, nothing)
+    store_iter = false
+    store_J_T = false
+    store_grad_norm = false
+    store_ΔJ_T = false
+    store_counts = false
+    store_secs = false
+    for item in store_iter_info
+        if item == "iter."
+            store_iter = true
+        elseif item == "J_T"
+            store_J_T = true
+        elseif item == "|∇J_T|"
+            store_grad_norm = true
+        elseif item == "ΔJ_T"
+            store_ΔJ_T = true
+        elseif item == "FG(F)"
+            store_counts = true
+        elseif item == "secs"
+            store_secs = true
+        else
+            msg = "Item $(repr(item)) in `store_iter_info` is not one of $(repr(headers)))"
+            throw(ArgumentError(msg))
+        end
+    end

-    iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
-    width = Dict(
-        "iter." => max(length("$iter_stop"), 6),
-        "J_T" => 11,
-        "|∇J_T|" => 11,
-        "|∇J_a|" => 11,
-        "|∇J|" => 11,
-        "ΔJ" => 11,
-        "ΔJ_T" => 11,
-        "FG(F)" => 8,
-        "secs" => 8,
-    )
-
-    if iteration == 0
-        for header in headers
+    function print_table(wrk, iteration, args...)
+
+        J_T = wrk.result.J_T
+        ΔJ_T = J_T - wrk.result.J_T_prev
+        secs = wrk.result.secs
+        grad_norm = norm(wrk.grad_J_T)
+        counts = Tuple(wrk.fg_count)
+
+        iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
+        width = Dict(
+            "iter." => max(length("$iter_stop"), 6),
+            "J_T" => 11,
+            "|∇J_T|" => 11,
+            "|∇J_a|" => 11,
+            "|∇J|" => 11,
+            "ΔJ" => 11,
+            "ΔJ_T" => 11,
+            "FG(F)" => 8,
+            "secs" => 8,
+        )
+
+        store_iter && (info_vals[1] = iteration)
+        store_J_T && (info_vals[2] = J_T)
+        store_grad_norm && (info_vals[3] = grad_norm)
+        store_ΔJ_T && (info_vals[4] = ΔJ_T)
+        store_counts && (info_vals[5] = counts)
+        store_secs && (info_vals[6] = secs)
+
+        if iteration == 0
+            for header in headers
+                w = width[header]
+                print(lpad(header, w))
+            end
+            print("\n")
+        end
+
+        strs = [
+            "$iteration",
+            @sprintf("%.2e", J_T),
+            @sprintf("%.2e", grad_norm),
+            (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
+            @sprintf("%d(%d)", counts[1], counts[2]),
+            @sprintf("%.1f", secs),
+        ]
+        for (str, header) in zip(strs, headers)
+            w = width[header]
+            print(lpad(str, w))
+        end
+        print("\n")
+        flush(stdout)
+
+        return Tuple((value for value in info_vals if (value !== nothing)))
+    end

-    strs = [
-        "$iteration",
-        @sprintf("%.2e", J_T),
-        @sprintf("%.2e", norm(wrk.grad_J_T)),
-        (iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
-        @sprintf("%d(%d)", wrk.fg_count[1], wrk.fg_count[2]),
-        @sprintf("%.1f", secs),
-    ]
-    if wrk.J_parts[2] ≠ 0.0
-        insert!(strs, 4, @sprintf("%.2e", norm(wrk.grad_J_a)))
-    end
-    for (str, header) in zip(strs, headers)
-        w = width[header]
-        print(lpad(str, w))
-    end
-    print("\n")
-    flush(stdout)
-end
+
+    return print_table
+
+end


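Taken together, the new keyword arguments documented above might be used along these lines (a hypothetical sketch: `problem` stands for a pre-defined control problem, and `my_callback` is the user callback from the earlier example):

```julia
using QuantumControl: optimize
using GRAPE

result = optimize(
    problem;
    method=GRAPE,
    print_iters=true,                  # default: print the progress table
    store_iter_info=["iter.", "J_T"],  # columns to keep in result.records
    callback=my_callback,              # optional additional user callback
    check_convergence=res -> begin
        # chainable convergence check, as described in the docstring
        if res.J_T < 1e-3
            res.converged = true
            res.message = "J_T < 10⁻³"
        end
        res
    end,
)
```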
4 changes: 2 additions & 2 deletions src/result.jl
@@ -16,7 +16,7 @@ The attributes of a `GrapeResult` object include
discretized to the points of `tlist`)
* optimized_controls: A vector of the optimized control fields in the current
  iteration
-* records: A vector of tuples with values returned by an `info_hook` routine
+* records: A vector of tuples with values returned by a `callback` routine
passed to [`optimize`](@ref)
* converged: A boolean flag on whether the optimization is converged. This
may be set to `true` by a `check_convergence` function.
@@ -40,7 +40,7 @@ mutable struct GrapeResult{STST}
states::Vector{STST}
start_local_time::DateTime
end_local_time::DateTime
-records::Vector{Tuple} # storage for info_hook to write data into at each iteration
+records::Vector{Tuple} # storage for callback to write data into at each iteration
converged::Bool
f_calls::Int64
fg_calls::Int64
Expand Down