Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for prop_callback #51

Merged
merged 1 commit into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using QuantumControl.Functionals: make_chi, taus!
using QuantumControl: set_atexit_save_optimization
using QuantumControl: @threadsif, Trajectory
using QuantumControl.QuantumPropagators: _StoreState
using LinearAlgebra
using Printf

Expand Down Expand Up @@ -219,9 +220,13 @@
N_T = length(wrk.result.tlist) - 1
for n = 1:N_T
Ψₖ = prop_step!(wrk.fw_propagators[k])
if haskey(wrk.fw_prop_kwargs[k], :callback)
local cb = wrk.fw_prop_kwargs[k][:callback]
local observables = get(wrk.fw_prop_kwargs[k], :observables, _StoreState())
cb(wrk.fw_propagators[k], observables)

Check warning on line 226 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L224-L226

Added lines #L224 - L226 were not covered by tests
end
(Φ₀ !== nothing) && write_to_storage!(Φ₀, n + 1, Ψₖ)
end
# TODO: allow a custom prop_step! routine
end


Expand Down Expand Up @@ -267,6 +272,11 @@
write_to_storage!(X[k], N_T + 1, χ[k])
for n = N_T:-1:1
local χₖ = prop_step!(wrk.bw_propagators[k])
if haskey(wrk.bw_prop_kwargs[k], :callback)
local cb = wrk.bw_prop_kwargs[k][:callback]
local observables = get(wrk.bw_prop_kwargs[k], :observables, _StoreState())
cb(wrk.bw_propagators[k], observables)

Check warning on line 278 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L276-L278

Added lines #L276 - L278 were not covered by tests
end
write_to_storage!(X[k], n, χₖ)
end
end
Expand Down Expand Up @@ -314,6 +324,11 @@
# TODO: end of self-consistent loop
@threadsif wrk.use_threads for k = 1:N
local Ψₖ = prop_step!(wrk.fw_propagators[k])
if haskey(wrk.fw_prop_kwargs[k], :callback)
local cb = wrk.fw_prop_kwargs[k][:callback]
local observables = get(wrk.fw_prop_kwargs[k], :observables, _StoreState())
cb(wrk.fw_propagators[k], observables)

Check warning on line 330 in src/optimize.jl

View check run for this annotation

Codecov / codecov/patch

src/optimize.jl#L328-L330

Added lines #L328 - L330 were not covered by tests
end
write_to_storage!(Φ[k], n, Ψₖ)
end
# TODO: update sigma
Expand Down
11 changes: 10 additions & 1 deletion src/workspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ mutable struct KrotovWrk
result

#################################
# scratch objects, per trajectory:
# Per trajectory:

control_derivs
fw_prop_kwargs::Vector{Dict{Symbol,Any}}
bw_prop_kwargs::Vector{Dict{Symbol,Any}}
fw_storage # forward storage array (per trajectory)
fw_storage2 # forward storage array (per trajectory)
bw_storage # backward storage array (per trajectory)
Expand All @@ -63,6 +65,7 @@ end
function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false)
use_threads = get(problem.kwargs, :use_threads, false)
trajectories = [traj for traj in problem.trajectories]
N = length(trajectories)
adjoint_trajectories = [adjoint(traj) for traj in problem.trajectories]
controls = get_controls(trajectories)
if length(controls) == 0
Expand Down Expand Up @@ -128,6 +131,8 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false)
bw_storage = [init_storage(traj.initial_state, tlist) for traj in trajectories]
kwargs[:piecewise] = true # only accept piecewise propagators
_prefixes = ["prop_", "fw_prop_"]
fw_prop_kwargs = [Dict{Symbol,Any}() for _ = 1:N]
bw_prop_kwargs = [Dict{Symbol,Any}() for _ = 1:N]
fw_propagators = [
init_prop_trajectory(
traj,
Expand All @@ -136,6 +141,7 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false)
_msg="Initializing fw-prop of trajectory $k",
_prefixes,
_filter_kwargs=true,
_kwargs_dict=fw_prop_kwargs[k],
kwargs...
) for (k, traj) in enumerate(trajectories)
]
Expand All @@ -149,6 +155,7 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false)
_prefixes,
_filter_kwargs=true,
bw_prop_backward=true, # will filter to `backward=true`
_kwargs_dict=bw_prop_kwargs[k],
kwargs...
) for (k, traj) in enumerate(adjoint_trajectories)
]
Expand Down Expand Up @@ -181,6 +188,8 @@ function KrotovWrk(problem::QuantumControl.ControlProblem; verbose=false)
chi_takes_tau,
result,
control_derivs,
fw_prop_kwargs,
bw_prop_kwargs,
fw_storage,
fw_storage2,
bw_storage,
Expand Down