diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index bd62bb41..a6f67f07 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -190,7 +190,7 @@ function mcsolveProblem( c_ops isa Nothing && throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead.")) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2 diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index f83a1bc3..363e0790 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -122,7 +122,7 @@ function mesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl ρ0 = mat2vec(ket2dm(ψ0).data) diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index 6ef530d6..a550e48a 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -103,7 +103,7 @@ function sesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = _convert_tlist(eltype(H), tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl ϕ0 = get_data(ψ0) diff --git a/src/utilities.jl b/src/utilities.jl index df93872d..48762445 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -65,3 +65,13 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} = @warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" * join(arg, ", ") * ")` instead of `$argname = $arg`." maxlog = 1 + +# convert tlist in time evolution +_convert_tlist(::Int32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::Float32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::ComplexF32, tlist::AbstractVector) = _convert_tlist(Val(32), tlist) +_convert_tlist(::Int64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::Float64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::ComplexF64, tlist::AbstractVector) = _convert_tlist(Val(64), tlist) +_convert_tlist(::Val{32}, tlist::AbstractVector) = convert(Vector{Float32}, tlist) +_convert_tlist(::Val{64}, tlist::AbstractVector) = convert(Vector{Float64}, tlist) \ No newline at end of file