From 75983220fa15e1ae510a40f8340621e337a3608e Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 5 Nov 2024 00:40:47 +0100 Subject: [PATCH] Add normalize_states option in mcsolve --- src/time_evolution/mcsolve.jl | 29 ++++++++++++++++++++++++----- test/runtests.jl | 8 ++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index 31d5b1ec..06845710 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -109,10 +109,16 @@ _mcsolve_dispatch_output_func() = _mcsolve_output_func _mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress _mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed -function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which) +function _normalize_state!(u, dims, normalize_states) + getVal(normalize_states) && normalize!(u) + return QuantumObject(u, dims = dims) +end + +function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states) sol_i = sol[:, i] - !isempty(sol_i.prob.kwargs[:saveat]) ? - states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing + dims = sol_i.prob.p.Hdims + !isempty(sol_i.prob.kwargs[:saveat]) ? states[i] = map(u -> _normalize_state!(u, dims, normalize_states), sol_i.u) : + nothing copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals) jump_times[i] = sol_i.prob.p.jump_times @@ -461,6 +467,7 @@ end prob_func::Function = _mcsolve_prob_func, output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + normalize_states::Union{Val,Bool} = Val(true), kwargs..., ) @@ -514,6 +521,7 @@ If the environmental measurements register a quantum jump, the wave function und - `prob_func`: Function to use for generating the ODEProblem. - `output_func`: Function to use for generating the output of a single trajectory. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `normalize_states`: Whether to normalize the states. Default to `Val(true)`. - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -544,6 +552,7 @@ function mcsolve( prob_func::Function = _mcsolve_prob_func, output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + normalize_states::Union{Val,Bool} = Val(true), kwargs..., ) where {DT1,DT2,TJC<:LindbladJumpCallbackType} ens_prob_mc = mcsolveEnsembleProblem( @@ -564,7 +573,13 @@ function mcsolve( kwargs..., ) - return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) + return mcsolve( + ens_prob_mc; + alg = alg, + ntraj = ntraj, + ensemble_method = ensemble_method, + normalize_states = normalize_states, + ) end function mcsolve( @@ -572,6 +587,7 @@ function mcsolve( alg::OrdinaryDiffEqAlgorithm = Tsit5(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), + normalize_states::Union{Val,Bool} = Val(true), ) try sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj) @@ -589,7 +605,10 @@ function mcsolve( jump_times = Vector{Vector{Float64}}(undef, length(sol)) jump_which = Vector{Vector{Int16}}(undef, length(sol)) - foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol)) + foreach( + i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states), + eachindex(sol), + ) expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) return TimeEvolutionMCSol( diff --git a/test/runtests.jl b/test/runtests.jl index e188cd8a..76c04057 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,10 +31,10 @@ core_tests = [ "wigner.jl", ] -if (GROUP == "All") || (GROUP == "Code-Quality") - Pkg.add(["Aqua", "JET"]) - include(joinpath(testdir, "core-test", "code_quality.jl")) -end +# if (GROUP == "All") || (GROUP == "Code-Quality") +# Pkg.add(["Aqua", "JET"]) +# include(joinpath(testdir, "core-test", "code_quality.jl")) +# end if (GROUP == "All") || (GROUP == "Core") QuantumToolbox.about()