Skip to content

Commit

Permalink
fix type conversion of tlist in time evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang committed Sep 19, 2024
1 parent 93c3ff6 commit da37b68
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function mcsolveProblem(
c_ops isa Nothing &&
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2

Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function mesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

ρ0 = mat2vec(ket2dm(ψ0).data)

Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function sesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

ϕ0 = get_data(ψ0)

Expand Down
10 changes: 10 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,13 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1

# convert tlist in time evolution
_convert_tlist(::Int32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
_convert_tlist(::Float32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
_convert_tlist(::ComplexF32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist)
_convert_tlist(::Int64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
_convert_tlist(::Float64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
_convert_tlist(::ComplexF64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist)
_convert_tlist(::Val{32}, tlist::AbstractVector) = convert(Vector{Float32}, tlist)
_convert_tlist(::Val{64}, tlist::AbstractVector) = convert(Vector{Float64}, tlist)

0 comments on commit da37b68

Please sign in to comment.