From da37b688678b68c922ef3d957ea162422ddaa5ff Mon Sep 17 00:00:00 2001 From: Yi-Te Huang Date: Thu, 19 Sep 2024 13:01:03 +0800 Subject: [PATCH] fix type conversion of `tlist` in time evolution --- src/time_evolution/mcsolve.jl | 2 +- src/time_evolution/mesolve.jl | 2 +- src/time_evolution/sesolve.jl | 2 +- src/utilities.jl | 10 ++++++++++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index bd62bb41..a6f67f07 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -190,7 +190,7 @@ function mcsolveProblem( c_ops isa Nothing && throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead.")) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2 diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index f83a1bc3..363e0790 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -122,7 +122,7 @@ function mesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl ρ0 = mat2vec(ket2dm(ψ0).data) diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index 6ef530d6..a550e48a 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -103,7 +103,7 @@ function sesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl ϕ0 = get_data(ψ0) diff --git a/src/utilities.jl b/src/utilities.jl index df93872d..48762445 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -65,3 +65,13 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} = @warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" * join(arg, ", ") * ")` instead of `$argname = $arg`." maxlog = 1 + +# convert tlist in time evolution +_convert_tlist(::Int32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::Float32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::ComplexF32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::Int64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::Float64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::ComplexF64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::Val{32}, tlist::AbstractVector) = convert(Vector{Float32}, tlist) +_convert_tlist(::Val{64}, tlist::AbstractVector) = convert(Vector{Float64}, tlist) \ No newline at end of file