diff --git a/src/optimize.jl b/src/optimize.jl index e1f65e0..c7803ff 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -7,6 +7,7 @@ using QuantumControl.QuantumPropagators.Storage: using QuantumControl.Functionals: make_chi, taus! using QuantumControl: set_atexit_save_optimization using QuantumControl: @threadsif, Trajectory +using QuantumControl.QuantumPropagators: _StoreState using LinearAlgebra using Printf @@ -219,9 +220,13 @@ function krotov_initial_fw_prop!(ϵ⁽⁰⁾, ϕₖ, k, wrk) N_T = length(wrk.result.tlist) - 1 for n = 1:N_T Ψₖ = prop_step!(wrk.fw_propagators[k]) + if haskey(wrk.fw_prop_kwargs[k], :callback) + local cb = wrk.fw_prop_kwargs[k][:callback] + local observables = get(wrk.fw_prop_kwargs[k], :observables, _StoreState()) + cb(wrk.fw_propagators[k], observables) + end (Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, Ψₖ) end - # TODO: allow a custom prop_step! routine end @@ -267,6 +272,11 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾) write_to_storage!(X[k], N_T + 1, χ[k]) for n = N_T:-1:1 local χₖ = prop_step!(wrk.bw_propagators[k]) + if haskey(wrk.bw_prop_kwargs[k], :callback) + local cb = wrk.bw_prop_kwargs[k][:callback] + local observables = get(wrk.bw_prop_kwargs[k], :observables, _StoreState()) + cb(wrk.bw_propagators[k], observables) + end write_to_storage!(X[k], n, χₖ) end end @@ -314,6 +324,11 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾) # TODO: end of self-consistent loop @threadsif wrk.use_threads for k = 1:N local Ψₖ = prop_step!(wrk.fw_propagators[k]) + if haskey(wrk.fw_prop_kwargs[k], :callback) + local cb = wrk.fw_prop_kwargs[k][:callback] + local observables = get(wrk.fw_prop_kwargs[k], :observables, _StoreState()) + cb(wrk.fw_propagators[k], observables) + end write_to_storage!(Φ[k], n, Ψₖ) end # TODO: update sigma diff --git a/src/workspace.jl b/src/workspace.jl index f91c0d3..6954038 100644 --- a/src/workspace.jl +++ b/src/workspace.jl @@ -47,9 +47,11 @@ mutable struct KrotovWrk result ################################# - # scratch objects, per trajectory: + # Per trajectory: control_derivs + fw_prop_kwargs::Vector{Dict{Symbol,Any}} + bw_prop_kwargs::Vector{Dict{Symbol,Any}} fw_storage # forward storage array (per trajectory) fw_storage2 # forward storage array (per trajectory) bw_storage # backward storage array (per trajectory) @@ -63,6 +65,7 @@ end function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false) use_threads = get(problem.kwargs, :use_threads, false) trajectories = [traj for traj in problem.trajectories] + N = length(trajectories) adjoint_trajectories = [adjoint(traj) for traj in problem.trajectories] controls = get_controls(trajectories) if length(controls) == 0 @@ -128,6 +131,8 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false) bw_storage = [init_storage(traj.initial_state, tlist) for traj in trajectories] kwargs[:piecewise] = true # only accept piecewise propagators _prefixes = ["prop_", "fw_prop_"] + fw_prop_kwargs = [Dict{Symbol,Any}() for _ = 1:N] + bw_prop_kwargs = [Dict{Symbol,Any}() for _ = 1:N] fw_propagators = [ init_prop_trajectory( traj, @@ -136,6 +141,7 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false) _msg="Initializing fw-prop of trajectory $k", _prefixes, _filter_kwargs=true, + _kwargs_dict=fw_prop_kwargs[k], kwargs... ) for (k, traj) in enumerate(trajectories) ] @@ -149,6 +155,7 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false) _prefixes, _filter_kwargs=true, bw_prop_backward=true, # will filter to `backward=true` + _kwargs_dict=bw_prop_kwargs[k], kwargs... ) for (k, traj) in enumerate(adjoint_trajectories) ] @@ -181,6 +188,8 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false) chi_takes_tau, result, control_derivs, + fw_prop_kwargs, + bw_prop_kwargs, fw_storage, fw_storage2, bw_storage,