Skip to content

Commit

Permalink
Add normalize_states option in mcsolve (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio authored Nov 5, 2024
1 parent 2adcd8a commit 9b23c6f
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,16 @@ _mcsolve_dispatch_output_func() = _mcsolve_output_func
_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress
_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed

function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
function _normalize_state!(u, dims, normalize_states)
getVal(normalize_states) && normalize!(u)
return QuantumObject(u, dims = dims)
end

function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states)
sol_i = sol[:, i]
!isempty(sol_i.prob.kwargs[:saveat]) ?
states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing
dims = sol_i.prob.p.Hdims
!isempty(sol_i.prob.kwargs[:saveat]) ? states[i] = map(u -> _normalize_state!(u, dims, normalize_states), sol_i.u) :
nothing

copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
jump_times[i] = sol_i.prob.p.jump_times
Expand Down Expand Up @@ -461,6 +467,7 @@ end
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
normalize_states::Union{Val,Bool} = Val(true),
kwargs...,
)
Expand Down Expand Up @@ -514,6 +521,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `prob_func`: Function to use for generating the ODEProblem.
- `output_func`: Function to use for generating the output of a single trajectory.
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `normalize_states`: Whether to normalize the states. Default to `Val(true)`.
- `kwargs`: The keyword arguments for the ODEProblem.
# Notes
Expand Down Expand Up @@ -544,6 +552,7 @@ function mcsolve(
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
normalize_states::Union{Val,Bool} = Val(true),
kwargs...,
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
ens_prob_mc = mcsolveEnsembleProblem(
Expand All @@ -564,14 +573,21 @@ function mcsolve(
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
return mcsolve(
ens_prob_mc;
alg = alg,
ntraj = ntraj,
ensemble_method = ensemble_method,
normalize_states = normalize_states,
)
end

function mcsolve(
ens_prob_mc::EnsembleProblem;
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
normalize_states::Union{Val,Bool} = Val(true),
)
try
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
Expand All @@ -589,7 +605,10 @@ function mcsolve(
jump_times = Vector{Vector{Float64}}(undef, length(sol))
jump_which = Vector{Vector{Int16}}(undef, length(sol))

foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
foreach(
i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states),
eachindex(sol),
)
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionMCSol(
Expand Down

0 comments on commit 9b23c6f

Please sign in to comment.