Skip to content

Commit

Permalink
Replace n_traj with ntraj for QuTiP compatibility (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Sep 28, 2024
2 parents bdac5f4 + db5d485 commit 1913fd9
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 52 deletions.
4 changes: 2 additions & 2 deletions benchmarks/timeevolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function benchmark_timeevolution!(SUITE)
$ψ0,
$tlist,
$c_ops,
n_traj = 100,
ntraj = 100,
e_ops = $e_ops,
progress_bar = Val(false),
ensemble_method = EnsembleSerial(),
Expand All @@ -57,7 +57,7 @@ function benchmark_timeevolution!(SUITE)
$ψ0,
$tlist,
$c_ops,
n_traj = 100,
ntraj = 100,
e_ops = $e_ops,
progress_bar = Val(false),
ensemble_method = EnsembleThreads(),
Expand Down
2 changes: 1 addition & 1 deletion docs/src/users_guide/steadystate.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ exp_ss = real(expect(e_ops[1], ρ_ss))
tlist = LinRange(0, 50, 100)
# monte-carlo
sol_mc = mcsolve(H, ψ0, tlist, c_op_list, e_ops=e_ops, n_traj=100, progress_bar=false)
sol_mc = mcsolve(H, ψ0, tlist, c_op_list, e_ops=e_ops, ntraj=100, progress_bar=false)
exp_mc = real(sol_mc.expect[1, :])
# master eq.
Expand Down
18 changes: 9 additions & 9 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ end
e_ops::Union{Nothing,AbstractVector}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
n_traj::Int=1,
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
kwargs...)
Expand Down Expand Up @@ -452,7 +452,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `n_traj::Int`: Number of trajectories to use.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
- `prob_func::Function`: Function to use for generating the ODEProblem.
Expand Down Expand Up @@ -482,15 +482,15 @@ function mcsolve(
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
n_traj::Int = 1,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
if !isnothing(seeds) && length(seeds) != n_traj
throw(ArgumentError("Length of seeds must match n_traj ($n_traj), but got $(length(seeds))"))
if !isnothing(seeds) && length(seeds) != ntraj
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
end

ens_prob_mc = mcsolveEnsembleProblem(
Expand All @@ -509,16 +509,16 @@ function mcsolve(
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, n_traj = n_traj, ensemble_method = ensemble_method)
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
end

function mcsolve(
ens_prob_mc::EnsembleProblem;
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
n_traj::Int = 1,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
)
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = n_traj)
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
Expand All @@ -536,7 +536,7 @@ function mcsolve(
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionMCSol(
n_traj,
ntraj,
times,
states,
expvals,
Expand Down
14 changes: 7 additions & 7 deletions src/time_evolution/ssesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ end
e_ops::Union{Nothing,AbstractVector}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
n_traj::Int=1,
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
prob_func::Function=_mcsolve_prob_func,
output_func::Function=_mcsolve_output_func,
Expand Down Expand Up @@ -334,7 +334,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `n_traj::Int`: Number of trajectories to use.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `prob_func::Function`: Function to use for generating the SDEProblem.
- `output_func::Function`: Function to use for generating the output of a single trajectory.
Expand Down Expand Up @@ -362,7 +362,7 @@ function ssesolve(
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
n_traj::Int = 1,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
prob_func::Function = _ssesolve_prob_func,
output_func::Function = _ssesolve_output_func,
Expand All @@ -382,16 +382,16 @@ function ssesolve(
kwargs...,
)

return ssesolve(ens_prob; alg = alg, n_traj = n_traj, ensemble_method = ensemble_method)
return ssesolve(ens_prob; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
end

function ssesolve(
ens_prob::EnsembleProblem;
alg::StochasticDiffEqAlgorithm = SRA1(),
n_traj::Int = 1,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
)
sol = solve(ens_prob, alg, ensemble_method, trajectories = n_traj)
sol = solve(ens_prob, alg, ensemble_method, trajectories = ntraj)
_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
Expand All @@ -403,7 +403,7 @@ function ssesolve(
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionSSESol(
n_traj,
ntraj,
_sol_1.t,
states,
expvals,
Expand Down
12 changes: 6 additions & 6 deletions src/time_evolution/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ A structure storing the results and some information from solving quantum trajec
# Fields (Attributes)
- `n_traj::Int`: Number of trajectories
- `ntraj::Int`: Number of trajectories
- `times::AbstractVector`: The time list of the evolution in each trajectory.
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
Expand All @@ -70,7 +70,7 @@ struct TimeEvolutionMCSol{
TJT<:Vector{<:Vector{<:Real}},
TJW<:Vector{<:Vector{<:Integer}},
}
n_traj::Int
ntraj::Int
times::TT
states::TS
expect::TE
Expand All @@ -87,7 +87,7 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
print(io, "Solution of quantum trajectories\n")
print(io, "(converged: $(sol.converged))\n")
print(io, "--------------------------------\n")
print(io, "num_trajectories = $(sol.n_traj)\n")
print(io, "num_trajectories = $(sol.ntraj)\n")
print(io, "num_states = $(length(sol.states[1]))\n")
print(io, "num_expect = $(size(sol.expect, 1))\n")
print(io, "ODE alg.: $(sol.alg)\n")
Expand All @@ -100,7 +100,7 @@ end
struct TimeEvolutionSSESol
A structure storing the results and some information from solving trajectories of the Stochastic Shrodinger equation time evolution.
# Fields (Attributes)
- `n_traj::Int`: Number of trajectories
- `ntraj::Int`: Number of trajectories
- `times::AbstractVector`: The time list of the evolution in each trajectory.
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
Expand All @@ -118,7 +118,7 @@ struct TimeEvolutionSSESol{
T1<:Real,
T2<:Real,
}
n_traj::Int
ntraj::Int
times::TT
states::TS
expect::TE
Expand All @@ -133,7 +133,7 @@ function Base.show(io::IO, sol::TimeEvolutionSSESol)
print(io, "Solution of quantum trajectories\n")
print(io, "(converged: $(sol.converged))\n")
print(io, "--------------------------------\n")
print(io, "num_trajectories = $(sol.n_traj)\n")
print(io, "num_trajectories = $(sol.ntraj)\n")
print(io, "num_states = $(length(sol.states[1]))\n")
print(io, "num_expect = $(size(sol.expect, 1))\n")
print(io, "SDE alg.: $(sol.alg)\n")
Expand Down
6 changes: 3 additions & 3 deletions src/time_evolution/time_evolution_dynamical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ end
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
δα_list::Vector{<:Real}=fill(0.2, length(op_list)),
n_traj::Int=1,
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
jump_callback::LindbladJumpCallbackType=ContinuousLindbladJumpCallback(),
krylov_dim::Int=max(6, min(10, cld(length(ket2dm(ψ0).data), 4))),
Expand All @@ -712,7 +712,7 @@ function dsf_mcsolve(
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
δα_list::Vector{<:Real} = fill(0.2, length(op_list)),
n_traj::Int = 1,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
krylov_dim::Int = min(5, cld(length(ψ0.data), 3)),
Expand All @@ -736,5 +736,5 @@ function dsf_mcsolve(
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, n_traj = n_traj, ensemble_method = ensemble_method)
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
end
4 changes: 2 additions & 2 deletions test/core-test/dynamical-shifted-fock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
dsf_params,
e_ops = e_ops_dsf,
progress_bar = Val(false),
n_traj = 500,
ntraj = 500,
)
val_ss = abs2(sol0.expect[1, end])
@test sum(abs2.(sol0.expect[1, :] .- sol_dsf_me.expect[1, :])) / (val_ss * length(tlist)) < 0.1
Expand Down Expand Up @@ -140,7 +140,7 @@
dsf_params,
e_ops = e_ops_dsf2,
progress_bar = Val(false),
n_traj = 500,
ntraj = 500,
)

val_ss = abs2(sol0.expect[1, end])
Expand Down
36 changes: 14 additions & 22 deletions test/core-test/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
sol_me2 = mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
sol_me3 = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_mc_states = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, saveat = t_l, progress_bar = Val(false))
sol_sse = ssesolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_mc = mcsolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_mc_states = mcsolve(H, psi0, t_l, c_ops, ntraj = 500, saveat = t_l, progress_bar = Val(false))
sol_sse = ssesolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))

ρt_mc = [ket2dm.(normalize.(states)) for states in sol_mc_states.states]
expect_mc_states = mapreduce(states -> expect.(Ref(e_ops[1]), states), hcat, ρt_mc)
Expand Down Expand Up @@ -87,7 +87,7 @@
"Solution of quantum trajectories\n" *
"(converged: $(sol_mc.converged))\n" *
"--------------------------------\n" *
"num_trajectories = $(sol_mc.n_traj)\n" *
"num_trajectories = $(sol_mc.ntraj)\n" *
"num_states = $(length(sol_mc.states[1]))\n" *
"num_expect = $(size(sol_mc.expect, 1))\n" *
"ODE alg.: $(sol_mc.alg)\n" *
Expand All @@ -97,7 +97,7 @@
"Solution of quantum trajectories\n" *
"(converged: $(sol_sse.converged))\n" *
"--------------------------------\n" *
"num_trajectories = $(sol_sse.n_traj)\n" *
"num_trajectories = $(sol_sse.ntraj)\n" *
"num_states = $(length(sol_sse.states[1]))\n" *
"num_expect = $(size(sol_sse.expect, 1))\n" *
"SDE alg.: $(sol_sse.alg)\n" *
Expand All @@ -114,19 +114,11 @@
end

@testset "Type Inference mcsolve" begin
@inferred mcsolveEnsembleProblem(
H,
psi0,
t_l,
c_ops,
n_traj = 500,
e_ops = e_ops,
progress_bar = Val(false),
)
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
@inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false))
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false))
@inferred mcsolveEnsembleProblem(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, ntraj = 500, progress_bar = Val(true))
@inferred mcsolve(H, psi0, [0, 10], c_ops, ntraj = 500, progress_bar = Val(false))
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, ntraj = 500, progress_bar = Val(false))
end

@testset "Type Inference ssesolve" begin
Expand All @@ -135,12 +127,12 @@
psi0,
t_l,
c_ops,
n_traj = 500,
ntraj = 500,
e_ops = e_ops,
progress_bar = Val(false),
)
@inferred ssesolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred ssesolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, progress_bar = Val(true))
end
end

Expand Down Expand Up @@ -179,7 +171,7 @@
psi0 = kron(psi0_1, psi0_2)
t_l = LinRange(0, 20 / γ1, 1000)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false) # Here we don't put Val(false) because we want to test the support for Bool type
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = Val(false))
sol_mc = mcsolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = Val(false))
@test sum(abs.(sol_mc.expect[1:2, :] .- sol_me.expect[1:2, :])) / length(t_l) < 0.1
@test expect(sp1 * sm1, sol_me.states[end]) expect(sigmap() * sigmam(), ptrace(sol_me.states[end], 1))
end
Expand Down

0 comments on commit 1913fd9

Please sign in to comment.