Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Oct 16, 2024
1 parent 3a7382e commit db80e1b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 39 deletions.
46 changes: 16 additions & 30 deletions src/qobj/quantum_object_evo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,9 @@ end

# Make the QuantumObjectEvolution, with the option to pre-multiply by a scalar
function QuantumObjectEvolution(op_func_list::Tuple, α = true)
op = _get_op_func_first(op_func_list)
op, data = _generate_data(op_func_list, α)
dims = op.dims
type = op.type
T = eltype(op)

data = _generate_data(T, op_func_list, α)

# Preallocate the SciMLOperator cache using a dense vector as a reference
v0 = sparse_to_dense(similar(op.data, size(op, 1)))
Expand All @@ -109,12 +106,14 @@ end
QuantumObjectEvolution(op::QuantumObject, α = true) =
QuantumObjectEvolution(MatrixOperator* op.data), op.type, op.dims)

@generated function _get_op_func_first(op_func_list::Tuple)
@generated function _generate_data(op_func_list::Tuple, α)
op_func_list_types = op_func_list.parameters
N = length(op_func_list_types)
T = ()

dims_expr = ()
first_op = nothing
data_expr = :(0)

for i in 1:N
op_func_type = op_func_list_types[i]
if op_func_type <: Tuple
Expand All @@ -126,8 +125,8 @@ QuantumObjectEvolution(op::QuantumObject, α = true) =
"The first element must be a Operator or SuperOperator, and the second element must be a function.",
),
)

data_type = op_type.parameters[1]
T = (T..., eltype(data_type))
dims_expr = (dims_expr..., :(op_func_list[$i][1].dims))
if i == 1
first_op = :(op_func_list[$i][1])
Expand All @@ -136,47 +135,34 @@ QuantumObjectEvolution(op::QuantumObject, α = true) =
op_type = op_func_type
(isoper(op_type) || issuper(op_type)) ||
throw(ArgumentError("The element must be a Operator or SuperOperator."))

data_type = op_type.parameters[1]
T = (T..., eltype(data_type))
dims_expr = (dims_expr..., :(op_func_list[$i].dims))

if i == 1
first_op = :(op_func_list[$i])
end
end
data_expr = :($data_expr + _make_SciMLOperator(op_func_list[$i], α))
end

length(unique(T)) == 1 || throw(ArgumentError("The types of the operators must be the same."))

quote
dims = tuple($(dims_expr...))

length(unique(dims)) == 1 || throw(ArgumentError("The dimensions of the operators must be the same."))

return $first_op
return $first_op, $data_expr
end
end

@generated function _generate_data(T, op_func_list::Tuple, α)
op_func_list_types = op_func_list.parameters
N = length(op_func_list_types)
data_expr = :(0)
for i in 1:N
op_func_type = op_func_list_types[i]
if op_func_type <: Tuple
data_expr = :(
$data_expr +
ScalarOperator(zero(T), op_func_list[$i][2]) * MatrixOperator* op_func_list[$i][1].data)
)
else
data_expr = :($data_expr + MatrixOperator* op_func_list[$i].data))
end
end

quote
return $data_expr
end
function _make_SciMLOperator(op_func::Tuple, α)
T = eltype(op_func[1])
update_func = (a, u, p, t) -> op_func[2](p, t)
return ScalarOperator(zero(T), update_func) * MatrixOperator* op_func[1].data)
end

_make_SciMLOperator(op::QuantumObject, α) = MatrixOperator* op.data)

function (QO::QuantumObjectEvolution)(p, t)
# We put 0 in the place of `u` because the time-dependence doesn't depend on the state
update_coefficients!(QO.data, 0, p, t)
Expand Down
11 changes: 2 additions & 9 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _save_func_sesolve(integrator)
internal_params = integrator.p
progr = internal_params.progr

if internal_params.is_empty_e_ops
if !internal_params.is_empty_e_ops
e_ops = internal_params.e_ops
expvals = internal_params.expvals

Expand All @@ -16,13 +16,6 @@ function _save_func_sesolve(integrator)
return u_modified!(integrator, false)
end

sesolve_ti_dudt!(du, u, p, t) = mul!(du, p.U, u)
function sesolve_td_dudt!(du, u, p, t)
mul!(du, p.U, u)
H_t = p.H_t(t, p)
return mul!(du, H_t, u, -1im, 1)
end

function _generate_sesolve_kwargs_with_callback(t_l, kwargs)
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
kwargs2 =
Expand Down Expand Up @@ -132,7 +125,7 @@ function sesolveProblem(
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)

tspan = (t_l[1], t_l[end])
return ODEProblem{true}(U, ϕ0, tspan, p; kwargs3...)
return ODEProblem{true,FullSpecialize}(U, ϕ0, tspan, p; kwargs3...)
end

@doc raw"""
Expand Down

0 comments on commit db80e1b

Please sign in to comment.