From a8ddddf640d65b0817170757b3d538d7e574dc2c Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Wed, 17 Jul 2024 09:44:58 -0400 Subject: [PATCH] Make `chi` functions not-in-place MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `chi!` functions previously used by GRAPE and Krotov are now simply `chi` do not act in-place. This is more general and easier to implement for the user, as it allows to use immutable structs for states Note that in extreme performance-critical situations, one could still construct the χ-states in-place via a closure or functor. Both chi and J_T can now have an optional keyword argument `tau` (instead of the previous improperly implemented and unicode `τ`). Whether or not `tau` should be passed to these functions is automatically detected. --- src/optimize.jl | 89 +++++++++++++++++++++++++++--------------------- src/result.jl | 17 +++++---- src/workspace.jl | 20 ++++++++++- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index a6eb3c3..4d0ea67 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -1,8 +1,10 @@ using QuantumControlBase.QuantumPropagators.Generators: Operator using QuantumControlBase.QuantumPropagators.Controls: discretize, evaluate using QuantumControlBase.QuantumPropagators: prop_step!, reinit_prop!, propagate -using QuantumControlBase.QuantumPropagators.Storage: write_to_storage!, get_from_storage! +using QuantumControlBase.QuantumPropagators.Storage: + write_to_storage!, get_from_storage!, get_from_storage using QuantumControlBase: make_chi, set_atexit_save_optimization +using QuantumControlBase: taus! using QuantumControlBase: @threadsif, Trajectory using LinearAlgebra using Printf @@ -24,9 +26,12 @@ with explicit keyword arguments to `optimize`. # Required problem keyword arguments -* `J_T`: A function `J_T(ϕ, trajectories)` that evaluates the final time - functional from a list `ϕ` of forward-propagated states and - `problem.trajectories`. +* `J_T`: A function `J_T(Ψ, trajectories)` that evaluates the final time + functional from a list `Ψ` of forward-propagated states and + `problem.trajectories`. The function `J_T` may also take a keyword argument + `tau`. If it does, a vector containing the complex overlaps of the target + states (`target_state` property of each trajectory in `problem.trajectories`) + with the propagated states will be passed to `J_T`. # Recommended problem keyword arguments @@ -53,10 +58,12 @@ The following keyword arguments are supported (with default values): This overrides the global `lambda_a` and `update_shape` arguments. -* `chi`: A function `chi!(χ, ϕ, trajectories)` that receives a list `ϕ` - of the forward propagated states and must set ``|χₖ⟩ = -∂J_T/∂⟨ϕₖ|``. If not - given, it will be automatically determined from `J_T` via [`make_chi`](@ref - QuantumControlBase.make_chi) with the default parameters. +* `chi`: A function `chi(Ψ, trajectories)` that receives a list `Ψ` + of the forward propagated states and returns a vector of states + ``|χₖ⟩ = -∂J_T/∂⟨Ψₖ|``. If not given, it will be automatically determined + from `J_T` via [`make_chi`](@ref) with the default parameters. Similarly to + `J_T`, if `chi` accepts a keyword argument `tau`, it will be passed a vector + of complex overlaps. * `sigma=nothing`: A function that calculates the second-order contribution. If not given, the first-order Krotov method is used. * `iter_start=0`: The initial iteration number. @@ -132,14 +139,6 @@ function optimize_krotov(problem) ϵ⁽ⁱ⁾ = wrk.pulses0 ϵ⁽ⁱ⁺¹⁾ = wrk.pulses1 - if haskey(wrk.kwargs, :chi) - chi! = wrk.kwargs[:chi] - else - # we only want to evaluate `make_chi` if `chi` is not a kwarg - J_T_func = wrk.kwargs[:J_T] - chi! = make_chi(J_T_func, wrk.trajectories) - end - if skip_initial_forward_propagation @info "Skipping initial forward propagation" else @@ -170,7 +169,7 @@ function optimize_krotov(problem) try while !wrk.result.converged i = i + 1 - krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) + krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾) update_result!(wrk, i) info_tuple = callback(wrk, i, ϵ⁽ⁱ⁺¹⁾, ϵ⁽ⁱ⁾) if !(isnothing(info_tuple) || isempty(info_tuple)) @@ -208,18 +207,18 @@ function transform_control_ranges(c, ϵ_min, ϵ_max, check) end -function krotov_initial_fw_prop!(ϵ⁽⁰⁾, ϕₖⁱⁿ, k, wrk) +function krotov_initial_fw_prop!(ϵ⁽⁰⁾, ϕₖ, k, wrk) for propagator in wrk.fw_propagators propagator.parameters = IdDict(zip(wrk.controls, ϵ⁽⁰⁾)) end - reinit_prop!(wrk.fw_propagators[k], ϕₖⁱⁿ; transform_control_ranges) + reinit_prop!(wrk.fw_propagators[k], ϕₖ; transform_control_ranges) Φ₀ = wrk.fw_storage[k] - (Φ₀ !== nothing) && write_to_storage!(Φ₀, 1, ϕₖⁱⁿ) + (Φ₀ !== nothing) && write_to_storage!(Φ₀, 1, ϕₖ) N_T = length(wrk.result.tlist) - 1 for n = 1:N_T - ϕₖ = prop_step!(wrk.fw_propagators[k]) - (Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, ϕₖ) + Ψₖ = prop_step!(wrk.fw_propagators[k]) + (Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, Ψₖ) end # TODO: allow a custom prop_step! routine end @@ -236,12 +235,8 @@ _eval_mu(μ::Operator, _...) = μ _eval_mu(μ::AbstractMatrix, _...) = μ -function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) +function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾) - χ = [ - (ismutable(propagator.state) ? propagator.state : similar(propagator.state)) for - propagator in wrk.bw_propagators - ] tlist = wrk.result.tlist N_T = length(tlist) - 1 N = length(wrk.trajectories) @@ -250,14 +245,22 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) Φ = wrk.fw_storage # TODO: pass in Φ₁, Φ₀ as parameters ∫gₐdt = wrk.g_a_int Im = imag + chi = wrk.kwargs[:chi] # guaranteed to exist in `KrotovWrk` constructor + guess_parameters = IdDict(zip(wrk.controls, ϵ⁽ⁱ⁾)) updated_parameters = IdDict(zip(wrk.controls, ϵ⁽ⁱ⁺¹⁾)) # backward propagation - ϕ = [propagator.state for propagator in wrk.fw_propagators] - chi!(χ, ϕ, wrk.trajectories) + + Ψ = [propagator.state for propagator in wrk.fw_propagators] + if wrk.chi_takes_tau + χ = chi(Ψ, wrk.trajectories; tau=wrk.result.tau_vals) + else + χ = chi(Ψ, wrk.trajectories) + end @threadsif wrk.use_threads for k = 1:N + # TODO: normalize χ; warn if norm is close to zero wrk.bw_propagators[k].parameters = guess_parameters reinit_prop!(wrk.bw_propagators[k], χ[k]; transform_control_ranges) write_to_storage!(X[k], N_T + 1, χ[k]) @@ -271,26 +274,30 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) @threadsif wrk.use_threads for k = 1:N wrk.fw_propagators[k].parameters = updated_parameters - local ϕₖ = wrk.trajectories[k].initial_state - reinit_prop!(wrk.fw_propagators[k], ϕₖ; transform_control_ranges) + local Ψₖ = wrk.trajectories[k].initial_state + reinit_prop!(wrk.fw_propagators[k], Ψₖ; transform_control_ranges) end ∫gₐdt .= 0.0 for n = 1:N_T # `n` is the index for the time interval dt = tlist[n+1] - tlist[n] for k = 1:N - get_from_storage!(χ[k], X[k], n) + if ismutable(χ[k]) + get_from_storage!(χ[k], X[k], n) + else + χ[k] = get_from_storage(X[k], n) + end end ϵₙ⁽ⁱ⁺¹⁾ = [ϵ⁽ⁱ⁾[l][n] for l ∈ 1:L] # ϵₙ⁽ⁱ⁺¹⁾ ≈ ϵₙ⁽ⁱ⁾ for non-linear controls # TODO: we could add a self-consistent loop here for ϵₙ⁽ⁱ⁺¹⁾ Δuₙ = zeros(L) # Δu is Δϵ without (Sₗₙ/λₐ) factor for l = 1:L # `l` is the index for the different controls for k = 1:N # k is the index over the trajectories - ϕₖ = wrk.fw_propagators[k].state + Ψₖ = wrk.fw_propagators[k].state μₖₗ = wrk.control_derivs[k][l] if !isnothing(μₖₗ) μₗₖₙ = _eval_mu(μₖₗ, wrk, ϵₙ⁽ⁱ⁺¹⁾, tlist, n) - Δuₙ[l] += Im(dot(χ[k], μₗₖₙ, ϕₖ)) + Δuₙ[l] += Im(dot(χ[k], μₗₖₙ, Ψₖ)) end end end @@ -305,8 +312,8 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾, chi!) end # TODO: end of self-consistent loop @threadsif wrk.use_threads for k = 1:N - local ϕₖ = prop_step!(wrk.fw_propagators[k]) - write_to_storage!(Φ[k], n, ϕₖ) + local Ψₖ = prop_step!(wrk.fw_propagators[k]) + write_to_storage!(Φ[k], n, Ψₖ) end # TODO: update sigma end # time loop @@ -315,12 +322,17 @@ end function update_result!(wrk::KrotovWrk, i::Int64) res = wrk.result - J_T_func = wrk.kwargs[:J_T] + J_T = wrk.kwargs[:J_T] res.J_T_prev = res.J_T for k in eachindex(wrk.fw_propagators) res.states[k] = wrk.fw_propagators[k].state end - res.J_T = J_T_func(res.states, wrk.trajectories) + taus!(res.tau_vals, res.states, wrk.trajectories; ignore_missing_target_state=true) + if wrk.J_T_takes_tau + res.J_T = J_T(res.states, wrk.trajectories; tau=res.tau_vals) + else + res.J_T = J_T(res.states, wrk.trajectories) + end (i > 0) && (res.iter = i) if i >= res.iter_stop res.converged = true @@ -331,7 +343,6 @@ function update_result!(wrk::KrotovWrk, i::Int64) prev_time = res.end_local_time res.end_local_time = now() res.secs = Dates.toms(res.end_local_time - prev_time) / 1000.0 - # TODO: calculate τ values end diff --git a/src/result.jl b/src/result.jl index ab92844..1287a7e 100644 --- a/src/result.jl +++ b/src/result.jl @@ -12,16 +12,19 @@ The attributes of a `KrotovResult` object include * `iter`: The number of the current iteration * `J_T`: The value of the final-time functional in the current iteration * `J_T_prev`: The value of the final-time functional in the previous iteration -* `tlist`: The time grid on which the control are discetized. +* `tlist`: The time grid on which the control are discretized. * `guess_controls`: A vector of the original control fields (each field discretized to the points of `tlist`) -* optimized_controls: A vector of the optimized control fileds in the current - iterations -* records: A vector of tuples with values returned by a `callback` routine +* `optimized_controls`: A vector of the optimized control fields. Calculated only + at the end of the optimization, not after each iteration. +* `tau_vals`: For any trajectory that defines a `target_state`, the complex + overlap of that target state with the propagated state. For any trajectory + for which the `target_state` is `nothing`, the value is zero. +* `records`: A vector of tuples with values returned by a `callback` routine passed to [`optimize`](@ref) -* converged: A boolean flag on whether the optimization is converged. This +* `converged`: A boolean flag on whether the optimization is converged. This may be set to `true` by a `check_convergence` function. -* message: A message string to explain the reason for convergence. This may be +* `message`: A message string to explain the reason for convergence. This may be set by a `check_convergence` function. All of the above attributes may be referenced in a `check_convergence` function @@ -53,7 +56,7 @@ function KrotovResult(problem) iter_stop = get(problem.kwargs, :iter_stop, 5000) iter = iter_start secs = 0 - tau_vals = Vector{ComplexF64}() + tau_vals = zeros(ComplexF64, length(problem.trajectories)) guess_controls = [discretize(control, tlist) for control in controls] J_T = 0.0 J_T_prev = 0.0 diff --git a/src/workspace.jl b/src/workspace.jl index fff90a3..c6b650d 100644 --- a/src/workspace.jl +++ b/src/workspace.jl @@ -41,6 +41,8 @@ mutable struct KrotovWrk g_a_int::Vector{Float64} update_shapes::Vector{Vector{Float64}} lambda_vals::Vector{Float64} + J_T_takes_tau::Bool # Does J_T have a tau keyword arg? + chi_takes_tau::Bool # Does chi have a tau keyword arg? # map of controls to options result @@ -150,6 +152,21 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false) kwargs... ) for (k, traj) in enumerate(adjoint_trajectories) ] + J_T_takes_tau = false + if haskey(kwargs, :J_T) + J_T = kwargs[:J_T] + else + msg = "`optimize` for `method=Krotov` must be passed the functional `J_T`." + throw(ArgumentError(msg)) + end + J_T_takes_tau = + hasmethod(J_T, Tuple{typeof(result.states),typeof(trajectories)}, (:tau,)) + if !haskey(kwargs, :chi) + kwargs[:chi] = make_chi(J_T, trajectories) + end + chi = kwargs[:chi] + chi_takes_tau = + hasmethod(chi, Tuple{typeof(result.states),typeof(trajectories)}, (:tau,)) return KrotovWrk( trajectories, adjoint_trajectories, @@ -160,7 +177,8 @@ function KrotovWrk(problem::QuantumControlBase.ControlProblem; verbose=false) g_a_int, update_shapes, lambda_vals, - # pulse_options, # XXX + J_T_takes_tau, + chi_takes_tau, result, control_derivs, fw_storage,