Skip to content

Commit

Permalink
Adapt to in-place interface
Browse files Browse the repository at this point in the history
Test optimization with static arrays
  • Loading branch information
goerz committed Jul 26, 2024
1 parent a8ddddf commit 75b766b
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ jobs:
version: '1'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/cache@v2
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v2
- name: "Instantiate test environment"
run: |
wget https://raw.githubusercontent.com/JuliaQuantumControl/JuliaQuantumControl/master/scripts/installorg.jl
Expand All @@ -48,10 +48,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/cache@v2
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: julia-actions/cache@v2
- name: "Install Python dependencies"
run: |
set -x
Expand Down
3 changes: 2 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using QuantumControlBase.QuantumPropagators.Generators: Operator
using QuantumControlBase.QuantumPropagators.Controls: discretize, evaluate
using QuantumControlBase.QuantumPropagators.Interfaces: supports_inplace
using QuantumControlBase.QuantumPropagators: prop_step!, reinit_prop!, propagate
using QuantumControlBase.QuantumPropagators.Storage:
write_to_storage!, get_from_storage!, get_from_storage
Expand Down Expand Up @@ -282,7 +283,7 @@ function krotov_iteration(wrk, ϵ⁽ⁱ⁾, ϵ⁽ⁱ⁺¹⁾)
for n = 1:N_T # `n` is the index for the time interval
dt = tlist[n+1] - tlist[n]
for k = 1:N
if ismutable(χ[k])
if supports_inplace(χ[k])
get_from_storage!(χ[k], X[k], n)
else
χ[k] = get_from_storage(X[k], n)
Expand Down
5 changes: 4 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
GRAPE = "6b52fcaf-80fe-489a-93e9-9f92080510be"
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Krotov = "b05dcdc7-62f6-4360-bf2c-0898bba419de"
Expand All @@ -20,17 +21,19 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuantumControl = "8a270532-f23f-47a8-83a9-b33d10cad486"
QuantumControlBase = "f10a33bc-5a64-497c-be7b-6f86b4f0c2aa"
QuantumControlTestUtils = "d3fd27c9-1dfb-4e67-b0c0-90d0d87a1e48"
QuantumGradientGenerators = "a563f35e-61db-434d-8c01-8b9e3ccdfd85"
QuantumPropagators = "7bf12567-5742-4b91-a078-644e72a65fc1"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TwoQubitWeylChamber = "cad078a0-0012-46f4-b55e-a945d44e115b"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Documenter = "1.1"
julia = "1.6"
julia = "1.9"
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ unicodeplots()
# Note: comment outer @testset to stop after first @safetestset failure
@time @testset verbose = true "Krotov.jl Package" begin

println("\n* TLS Optimization (test_tls_optimization.jl)")
@time @safetestset "TLS Optimization" begin
include("test_tls_optimization.jl")
end

println("\n* Pulse Optimization (test_pulse_optimization.jl)")
@time @safetestset "Pulse Optimization" begin
include("test_pulse_optimization.jl")
Expand Down
155 changes: 155 additions & 0 deletions test/test_tls_optimization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
using Test
using QuantumControl
using QuantumPropagators: ExpProp
using QuantumControl.Functionals: J_T_sm
using GRAPE
import Krotov
using LinearAlgebra
using Printf
import IOCapture
using StaticArrays: @SMatrix, @SVector

ϵ(t) = 0.2 * QuantumControl.Shapes.flattop(t, T=5, t_rise=0.3, func=:blackman);


"""Two-level-system Hamiltonian."""
function tls_hamiltonian=1.0, ϵ=ϵ)
σ̂_z = ComplexF64[
1 0
0 -1
]
σ̂_x = ComplexF64[
0 1
1 0
]
Ĥ₀ = -0.5 * Ω * σ̂_z
Ĥ₁ = σ̂_x
return hamiltonian(Ĥ₀, (Ĥ₁, ϵ))
end;


"""Two-level-system Hamiltonian, using StaticArrays."""
function tls_hamiltonian_static=1.0, ϵ=ϵ)
σ̂_z = @SMatrix ComplexF64[
1 0
0 -1
]
σ̂_x = @SMatrix ComplexF64[
0 1
1 0
]
Ĥ₀ = -0.5 * Ω * σ̂_z
Ĥ₁ = σ̂_x
return hamiltonian(Ĥ₀, (Ĥ₁, ϵ))
end;


@testset "TLS" begin

println("\n==================== TLS ===========================\n")
H = tls_hamiltonian()
tlist = collect(range(0, 5, length=501))
Ψ₀ = ComplexF64[1, 0]
Ψtgt = ComplexF64[0, 1]
problem = ControlProblem(
[Trajectory(Ψ₀, H, target_state=Ψtgt)],
tlist;
iter_stop=5,
prop_method=ExpProp,
J_T=J_T_sm,
check_convergence=res -> begin
((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
end,
)
res = optimize(problem; method=Krotov)
display(res)
@test res.J_T < 1e-3
@test 1.0 < maximum(abs.(res.optimized_controls[1])) < 1.2
println("===================================================\n")

end


@testset "TLS (static)" begin

println("\n================ TLS (static) ======================\n")
H = tls_hamiltonian_static()
tlist = collect(range(0, 5, length=501))
Ψ₀ = @SVector ComplexF64[1, 0]
Ψtgt = @SVector ComplexF64[0, 1]
problem = ControlProblem(
[Trajectory(Ψ₀, H, target_state=Ψtgt)],
tlist;
iter_stop=5,
prop_method=ExpProp,
J_T=J_T_sm,
check_convergence=res -> begin
((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
end,
)
res = optimize(problem; method=Krotov)
display(res)
@test res.J_T < 1e-3
@test 1.0 < maximum(abs.(res.optimized_controls[1])) < 1.2
println("===================================================\n")

end



@testset "TLS (continue from GRAPE)" begin

println("\n============ TLS (GRAPE continuation) ============\n")
H = tls_hamiltonian()
tlist = collect(range(0, 5, length=501))
Ψ₀ = ComplexF64[1, 0]
Ψtgt = ComplexF64[0, 1]
problem = ControlProblem(
[Trajectory(Ψ₀, H, target_state=Ψtgt)],
tlist;
iter_stop=5,
prop_method=ExpProp,
J_T=J_T_sm,
check_convergence=res -> begin
((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
end,
)
res_grape = optimize(problem; method=GRAPE, iter_stop=2)
res =
optimize(problem; method=Krotov, continue_from=res_grape, store_iter_info=["J_T"],)
display(res)
@test res.J_T < 1e-5
@test abs(res.records[1][1] - res_grape.J_T) < 1e-14
@test length(res.records) == 4
println("===================================================\n")

end


@testset "TLS (continue with GRAPE)" begin

println("\n=========== TLS (continue with GRAPE) ============\n")
H = tls_hamiltonian()
tlist = collect(range(0, 5, length=501))
Ψ₀ = ComplexF64[1, 0]
Ψtgt = ComplexF64[0, 1]
problem = ControlProblem(
[Trajectory(Ψ₀, H, target_state=Ψtgt)],
tlist;
iter_stop=5,
prop_method=ExpProp,
J_T=J_T_sm,
check_convergence=res -> begin
((res.J_T < 1e-10) && (res.converged = true) && (res.message = "J_T < 10⁻¹⁰"))
end,
)
res_krotov = optimize(problem; method=Krotov, iter_stop=2)
res =
optimize(problem; method=GRAPE, continue_from=res_krotov, store_iter_info=["J_T"],)
display(res)
@test res.J_T < 1e-3
@test length(res.records) == 4
@test abs(res.records[1][1] - res_krotov.J_T) < 1e-14
println("===================================================\n")

end

0 comments on commit 75b766b

Please sign in to comment.