Skip to content

Commit

Permalink
Properly define the in-place vs not-in-place interface
Browse files Browse the repository at this point in the history
We support both mutating and non-mutating propagators. Mutation is
better for large Hilbert spaces. Non-mutation is better for small
Hilbert spaces (`StaticArrays`!) or when trying to use automatic
differentiation.

There are some subtleties in finding the correct abstraction. It is not
as simple as using the built-in `ismutable` for states or operators and
making decisions based on that: Anytime we use custom structs, unless
that struct is explicitly defined as `mutable`, it is considered
immutable. However, we can still use in-place propagation, mutating the
mutable *components* of that struct.

Instead of overloading `ismutable`, we define the in-place or
not-in-place interface explicitly via the required behavior guaranteed
by the `check_state`, `check_generator`, and `check_operator` functions.

A new `QuantumPropagators.Interfaces.supports_inplace` function is
available to check whether a given `state` or `operator` type is
suitable for in-place operations.
  • Loading branch information
goerz committed Jul 21, 2024
1 parent d76fa79 commit 5023b66
Show file tree
Hide file tree
Showing 18 changed files with 504 additions and 335 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
QuantumPropagatorsODEExt = "OrdinaryDiffEq"
QuantumPropagatorsRecursiveArrayToolsExt = "RecursiveArrayTools"
QuantumPropagatorsStaticArraysExt = "StaticArrays"

[compat]
OffsetArrays = "1"
Expand Down
5 changes: 3 additions & 2 deletions ext/QuantumPropagatorsODEExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using QuantumPropagators:
ode_function
using QuantumPropagators.Controls:
get_controls, get_parameters, evaluate, evaluate!, discretize_on_midpoints
using QuantumPropagators.Interfaces: supports_inplace
import QuantumPropagators: init_prop, reinit_prop!, prop_step!, set_state!, set_t!


Expand All @@ -24,7 +25,7 @@ ode_propagator = init_prop(
generator,
tlist;
method=OrdinaryDiffEq, # or: `method=DifferentialEquations`
inplace=true,
inplace=QuantumPropagators.Interfaces.supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down Expand Up @@ -92,7 +93,7 @@ function init_prop(
generator,
tlist,
method::Val{:OrdinaryDiffEq};
inplace=true,
inplace=supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down
9 changes: 9 additions & 0 deletions ext/QuantumPropagatorsStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module QuantumPropagatorsStaticArraysExt

import QuantumPropagators.Interfaces: supports_inplace
using StaticArrays: SArray, MArray

supports_inplace(::SArray) = false
supports_inplace(::MArray) = true

end
35 changes: 19 additions & 16 deletions src/QuantumPropagators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,33 @@ include("propagator.jl")
export init_prop, reinit_prop!, prop_step!
# not exported: set_t!, set_state!

include("pwc_utils.jl")
include("cheby_propagator.jl")
include("newton_propagator.jl")
include("exp_propagator.jl")

include("ode_function.jl")

#! format: off
module Interfaces
export supports_inplace
export check_operator, check_state, check_tlist, check_amplitude
export check_control, check_generator, check_propagator
export check_parameterized_function, check_parameterized
include(joinpath("interfaces", "utils.jl"))
include(joinpath("interfaces", "state.jl"))
include(joinpath("interfaces", "tlist.jl"))
include(joinpath("interfaces", "operator.jl"))
include(joinpath("interfaces", "amplitude.jl"))
include(joinpath("interfaces", "control.jl"))
include(joinpath("interfaces", "generator.jl"))
include(joinpath("interfaces", "propagator.jl"))
include(joinpath("interfaces", "parameterization.jl"))
include("interfaces/supports_inplace.jl")
include("interfaces/utils.jl")
include("interfaces/state.jl")
include("interfaces/tlist.jl")
include("interfaces/operator.jl")
include("interfaces/amplitude.jl")
include("interfaces/control.jl")
include("interfaces/generator.jl")
include("interfaces/propagator.jl")
include("interfaces/parameterization.jl")
end
#! format: on

include("pwc_utils.jl")
include("cheby_propagator.jl")
include("newton_propagator.jl")
include("exp_propagator.jl")

include("ode_function.jl")


include("timings.jl")

# high-level interface
Expand Down
11 changes: 8 additions & 3 deletions src/cheby_propagator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .Controls: get_controls, evaluate, discretize
using .Interfaces: supports_inplace
using TimerOutputs: reset_timer!, @timeit_debug, TimerOutput

"""Propagator for Chebychev propagation (`method=QuantumPropagators.Cheby`).
Expand Down Expand Up @@ -38,7 +39,7 @@ cheby_propagator = init_prop(
generator,
tlist;
method=Cheby,
inplace=true,
inplace=QuantumPropagators.Interfaces.supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down Expand Up @@ -88,7 +89,7 @@ function init_prop(
generator,
tlist,
method::Val{:Cheby};
inplace=true,
inplace=supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down Expand Up @@ -356,7 +357,11 @@ function prop_step!(propagator::ChebyPropagator)
tlist = getfield(propagator, :tlist)
(0 < n < length(tlist)) || return nothing
if propagator.inplace
H = _pwc_set_genop!(propagator, n)
if supports_inplace(H)
H = _pwc_set_genop!(propagator, n)
else
H = _pwc_get_genop(propagator, n)
end
Cheby.cheby!(
Ψ,
H,
Expand Down
14 changes: 10 additions & 4 deletions src/exp_propagator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .Controls: get_controls
using .Interfaces: supports_inplace
using TimerOutputs: reset_timer!, @timeit_debug, TimerOutput

"""Propagator for propagation via direct exponentiation
Expand Down Expand Up @@ -46,7 +47,7 @@ exp_propagator = init_prop(
generator,
tlist;
method=ExpProp,
inplace=true,
inplace=QuantumPropagators.Interfaces.supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down Expand Up @@ -82,7 +83,7 @@ function init_prop(
generator,
tlist,
method::Val{:ExpProp};
inplace=true,
inplace=supports_inplace(state),
backward=false,
verbose=false,
parameters=nothing,
Expand Down Expand Up @@ -142,6 +143,7 @@ init_prop(state, generator, tlist, method::Val{:expprop}; kwargs...) =

function prop_step!(propagator::ExpPropagator)
@timeit_debug propagator.timing_data "prop_step!" begin
H = propagator.genop
n = propagator.n
tlist = getfield(propagator, :tlist)
(0 < n < length(tlist)) || return nothing
Expand All @@ -151,8 +153,12 @@ function prop_step!(propagator::ExpPropagator)
end
Ψ = convert(propagator.convert_state, propagator.state)
if propagator.inplace
_pwc_set_genop!(propagator, n)
H = convert(propagator.convert_operator, propagator.genop)
if supports_inplace(propagator.genop)
_pwc_set_genop!(propagator, n)
H = convert(propagator.convert_operator, propagator.genop)
else
H = convert(propagator.convert_operator, _pwc_get_genop(propagator, n))
end
ExpProp.expprop!(Ψ, H, dt, propagator.wrk; func=propagator.func)
if Ψ propagator.state # `convert` of Ψ may have been a no-op
copyto!(propagator.state, convert(typeof(propagator.state), Ψ))
Expand Down
1 change: 0 additions & 1 deletion src/expprop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ module ExpProp
export ExpPropWrk, expprop!

using LinearAlgebra
import StaticArrays
using TimerOutputs: @timeit_debug, TimerOutput


Expand Down
6 changes: 5 additions & 1 deletion src/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ constant `Number`. If the number of coefficients is less than the
number of operators, the first `ops` are considered to have ``c_l = 1``.
An `Operator` object would generally not be instantiated directly, but be
obtained from a (@ref) via [`evaluate`](@ref).
obtained from a [`Generator`](@ref) via [`evaluate`](@ref).
The ``Ĥ_l`` in the sum are considered immutable. This implies that an
`Operator` can be updated in-place with [`evaluate!`](@ref) by only changing
the `coeffs`.
"""
struct Operator{OT,CT<:Number}

Expand Down
35 changes: 11 additions & 24 deletions src/interfaces/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ using ..Generators: Generator
```julia
@test check_generator(
generator; state, tlist,
for_mutable_operator=true, for_immutable_operator=true,
for_mutable_state=true, for_immutable_state=true,
for_pwc=true, for_time_continuous=false,
for_expval=true, for_parameterization=false,
atol=1e-14, quiet=false)
Expand All @@ -26,19 +24,20 @@ verifies the given `generator`:
If `for_pwc` (default):
* [`evaluate(generator, tlist, n)`](@ref evaluate) must return a valid
* [`op = evaluate(generator, tlist, n)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* If `for_mutable_operator`,
* If `QuantumPropagators.Interfaces.supports_inplace(op)` is `true`,
[`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined
If `for_time_continuous`:
* [`evaluate(generator, t)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* If `for_mutable_operator`, [`evaluate!(op, generator, t)`](@ref evaluate!)
must be defined
* If `QuantumPropagators.Interfaces.supports_inplace(op)` is `true`,
[`evaluate!(op, generator, t)`](@ref evaluate!) must be defined
If `for_parameterization` (may require the `RecursiveArrayTools` package to be
loaded):
Expand All @@ -55,10 +54,6 @@ function check_generator(
generator;
state,
tlist,
for_mutable_operator=true,
for_immutable_operator=true,
for_mutable_state=true,
for_immutable_state=true,
for_expval=true,
for_pwc=true,
for_time_continuous=false,
Expand All @@ -69,7 +64,7 @@ function check_generator(
_check_amplitudes=true # undocumented (internal use)
)

@assert check_state(state; for_mutable_state, for_immutable_state, atol, quiet=true)
@assert check_state(state; atol, quiet=true)
@assert tlist isa Vector{Float64}
@assert length(tlist) >= 2

Expand Down Expand Up @@ -170,8 +165,6 @@ function check_generator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
Expand All @@ -189,14 +182,13 @@ function check_generator(
success = false
end

op = evaluate(generator, tlist, 1)

try
op = evaluate(generator, tlist, 1; vals_dict)
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
Expand All @@ -214,9 +206,8 @@ function check_generator(
success = false
end

if for_mutable_operator
if success && supports_inplace(op)
try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1)
catch exc
quiet || @error(
Expand Down Expand Up @@ -247,8 +238,6 @@ function check_generator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
Expand All @@ -272,8 +261,6 @@ function check_generator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
Expand All @@ -291,9 +278,9 @@ function check_generator(
success = false
end

if for_mutable_operator
op = evaluate(generator, tlist[begin])
if success && supports_inplace(op)
try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end])
catch exc
quiet || @error(
Expand Down
Loading

0 comments on commit 5023b66

Please sign in to comment.