Skip to content

Commit

Permalink
solvers and ode function profiling v1
Browse files Browse the repository at this point in the history
  • Loading branch information
jmsull committed Apr 11, 2023
1 parent 2479b12 commit 1e8f959
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 9 deletions.
173 changes: 165 additions & 8 deletions scripts/combined_experiment_iter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ k = Mpcfac*kclass #get k in our units
ℓ_mν=20
ℓ_ν=50
pertlen=2(ℓᵧ+1) + (ℓ_ν+1) + (ℓ_mν+1)*n_q + 5
reltol=1e-12
reltol=1e-6
#solve the hierarchy just to be sure
hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ, ℓ_ν, ℓ_mν,n_q);
results=zeros(pertlen,length(bg.x_grid));
Expand Down Expand Up @@ -73,8 +73,6 @@ end

planck_heavynu_ansatz_data = readdlm("./test/data/Bolt_allperts_mnu0p15_pholmax$(ℓᵧ)_msslsslmax$(ℓ_ν)_mssvlmax$(ℓ_mν).dat");



planck_pho_heavynu_ansatz_Θ₂ = linear_interpolation(planck_heavynu_ansatz_data[:,1],planck_heavynu_ansatz_data[:,2]);
planck_pho_heavynu_ansatz_Π = linear_interpolation(planck_heavynu_ansatz_data[:,1],planck_heavynu_ansatz_data[:,3]);
planck_nu_heavynu_ansatz₀ = [linear_interpolation(planck_heavynu_ansatz_data[:,1],planck_heavynu_ansatz_data[:,4]),
Expand Down Expand Up @@ -110,14 +108,9 @@ u0_ie = get_switch_u0(η_switch_use,hierarchy_conf,reltol);
M = 2048*4


h_boltsolve_conformal_flex(hierarchy_conf, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
initial_conditions(hierarchy_conf.hierarchy.bg.x_grid[1], hierarchy_conf.hierarchy),reltol=reltol);


xx_k,Θ₂,Π,𝒳₀_k,𝒳₂_k,perturb_k = itersolve(N_iters,cie_0,M,bg.x_grid[switch_idx],0.0,u0_ie;reltol=reltol);

#Accuracy/timing

@btime get_switch_u0(η_switch_use,hierarchy_conf,reltol);
#1.017 s (19372055 allocations: 439.42 MiB)
#Planck Ansatz (PA), Niter=2: 675.599 ms (19372055 allocations: 439.42 MiB)
Expand Down Expand Up @@ -153,8 +146,172 @@ xlims!(η_switch,bg.η[end])
ylabel!("Θ₂(η)")
xlabel!("η")

plot(perturb_conf.t,results_conf[1,:],color=:black,xscale=:log10)
plot!(bg.η(xx_k),perturb_k(bg.η(xx_k))[1,:],ls=:dash)
xlims!(1,1e4)
#--------------------------------
# Quick flamegraphs for ie! and hierarchy!


u0_h = initial_conditions(bg.x_grid[switch_idx], hierarchy_conf.hierarchy);
du_trunc_h = zero(u0_h);
println(size(du_trunc_h)) # (473,)
@btime Bolt.hierarchy!(du_trunc_h, u0_h, hierarchy, bg.x_grid[switch_idx]);
#48.005 μs (2421 allocations: 43.08 KiB)

du_trunc = zero(u0_ie);
println(size(du_trunc)) #(25,)
@btime Bolt.ie!(du_trunc, u0_ie, ie_0, bg.x_grid[switch_idx]);
# 25.497 μs (769 allocations: 20.19 KiB)

# A little surprising that the the truncated hierarchy is not doing much better than this since it is only 1/20th the size of the full hierarchy
# Let's look at flamegraphs to see why...


function hier_prof(N) #FIXME< can you just do this in a "begin" block?
for i in 1:N
Bolt.hierarchy!(du_trunc_h, u0_h, hierarchy, bg.x_grid[switch_idx]);
end
end
@profview hier_prof(10000)
#~1/2 of time spent in setting the massive neutrino derivative offset array
#~1/4 of time spent in neutrino ρ_σ integrals,
#~1/8 in massless neutrino and photon derivative offset array assignment
# small bit broadcasting final du at the end
# rest is negligible


function ie_prof(N)
for i in 1:N
Bolt.ie!(du_trunc, u0_ie, ie_0, bg.x_grid[switch_idx]);
end
end

@profview ie_prof(10000)
# Here 1/3 time in ρ_σ
# 1/6 maybe in massive dipole
# rest is about even between q computation, M0,M2 setting from interpolators, computing q points, etc.
# it looks like one of these is the iterator, so can probably remove that...
# can probably also save the q points somewehre?

#Ok so makes sense at this point, half the full time was coming from massive neutrino hierarchy
# and we totally eliminated that, but the next largest contribution is the ρ_σ integrals, which
# has the same cost in the ie!
# TODO optimizing these interals is the area to focus on:
# 1. Apply Zack's fix for background reusing
# 2. Perhaps save the values of q, ϵ, f0, and dlnf0dlnq so we don't need to recompute them constantly (same with temperature)
#^Add these to the background? (Perhaps best thing is to just save the products Iρ(xq[i]), Iσ(xq[i]) )
#^This only depends on qmin,qmax which is fixed when cosmology is known (i.e. in the bg), and has nothing to do with u vector
# 3. Can also explore more optimal choice of quadrature points...though this will be an assumption to ehcek

#--------------------------------
using ForwardDiff
# Check diffability of the ODE functions...
function test_AD_h(Ω_b::DT) where DT
𝕡 = CosmoParams{DT}(Ω_b=Ω_b)
bg = Background(𝕡; x_grid=-20.0:0.1:0.0, nq=15)
𝕣 = Bolt.RECFAST(bg=bg, Yp=𝕡.Y_p, OmegaB=𝕡.Ω_b,OmegaG=𝕡.Ω_r);
ih = IonizationHistory(𝕣, 𝕡, bg);
hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ, ℓ_ν, ℓ_mν,n_q);
η2x = linear_interpolation(bg.η,bg.x_grid);
hierarchy_conf = ConformalHierarchy(hierarchy,η2x);
switch_idx = argmin(abs.(bg.η .-η_switch))
u0_h = initial_conditions(bg.x_grid[switch_idx], hierarchy_conf.hierarchy);
du_trunc_h = zero(u0_h);
Bolt.hierarchy_conformal!(du_trunc_h, u0_h, hierarchy_conf, bg.η(bg.x_grid[switch_idx]));
return du_trunc_h[1]
end

test_AD_h(0.046)

ForwardDiff.derivative(test_AD_h, 0.046)
#Fix this here first


function test_AD_ie(Ω_b::DT) where DT
𝕡 = CosmoParams{DT}(Ω_b=Ω_b)
bg = Background(𝕡; x_grid=-20.0:0.1:0.0, nq=15)
𝕣 = Bolt.RECFAST(bg=bg, Yp=𝕡.Y_p, OmegaB=𝕡.Ω_b,OmegaG=𝕡.Ω_r);
ih = IonizationHistory(𝕣, 𝕡, bg);
hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ, ℓ_ν, ℓ_mν,n_q);
η2x = linear_interpolation(bg.η,bg.x_grid);
hierarchy_conf = ConformalHierarchy(hierarchy,η2x);
switch_idx = argmin(abs.(bg.η .-η_switch))
switch_idx = argmin(abs.(bg.η .-η_switch))
η2x_late = linear_interpolation(bg.η.(bg.x_grid[switch_idx:end]), bg.x_grid[switch_idx:end]);
ie_0 = IEγν(BasicNewtonian(), 𝕡, bg, ih, k,
planck_pho_heavynu_ansatz_Θ₂,planck_pho_heavynu_ansatz_Π,
planck_nu_heavynu_ansatz₀,planck_nu_heavynu_ansatz₂,
300,300,800, #exact optimal choice of these is k,η-dependent...
n_q);
cie_0 = ConformalIEγν(ie_0,η2x_late);
η_switch_use = bg.η[switch_idx];
u0_ie = get_switch_u0(η_switch_use,hierarchy_conf,reltol);
du_trunc = zero(u0_ie);
Bolt.ie!(du_trunc, u0_ie, ie_0, bg.x_grid[switch_idx]);
return du_trunc[1]
end

test_AD_ie(0.046) #in principle, this should be the same as the hiearchy derivative...why is it opposite sign??

ForwardDiff.derivative(test_AD_ie, 0.046)




#--------------------------------
# Quick check of a different solver for the truncated hierarchy, with a dependence on tolerance
# NB - NOT looking at accuracy rn
using OrdinaryDiffEq
reltol=1e-3;#1e-12;#1e-6; #1e-3 is almost certainly too loose, but 1e-6 is maybe passable

# Full hierarchy, KenCarp4
ode_alg = KenCarp4();
@btime h_boltsolve_conformal_flex(hierarchy_conf, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_h,ode_alg, reltol=reltol);
#1e-12 @btime 2.915 s (69272678 allocations: 1.36 GiB)
#1e-6 @btime 2.830 s (68685503 allocations: 1.34 GiB)
#1e-3 @btime 1.334 s (35940686 allocations: 765.21 MiB)

# Truncated hierarchy, KenCarp4
@btime boltsolve_conformal_flex(cie_0, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_ie,ode_alg, reltol=reltol);
#1e-12 @btime 1.036 s (20860835 allocations: 521.39 MiB)
#1e-6 @btime 558.906 ms (12224668 allocations: 305.68 MiB)
#1e-3 @btime 137.234 ms (3396905 allocations: 85.47 MiB)

# Full hierarchy, Rodas5P
ode_alg = Rodas5P();
@btime h_boltsolve_conformal_flex(hierarchy_conf, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_h,ode_alg, reltol=reltol);
#1e-12 @btime 40.638 s (1205521335 allocations: 26.62 GiB)
#1e-6 @btime 16.627 s (570224354 allocations: 12.61 GiB)
#1e-3 @btime 10.573 s (328733948 allocations: 7.27 GiB)

# Truncated hierarchy, Rodas5P
# FIXME getting AD error, come back to this after diffability test - not sure why this does not work, but hierarchy_conf does work?
@btime boltsolve_conformal_flex(cie_0, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_ie,ode_alg, reltol=reltol);
#1e-12 @btime N/A
#1e-6 @btime N/A
#1e-3 @btime N/A

# Full hierarchy, radau
ode_alg = RadauIIA5();
@btime h_boltsolve_conformal_flex(hierarchy_conf, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_h,ode_alg, reltol=reltol);
#1e-12 @btime 10.627 s (115457095 allocations: 2.54 GiB)
#1e-6 @btime 9.668 s (90545579 allocations: 1.95 GiB)
#1e-3 @btime 8.484 s (98425903 allocations: 2.14 GiB)

# Truncated hierarchy, radau
@btime boltsolve_conformal_flex(cie_0, η_switch_use, hierarchy_conf.hierarchy.bg.η(hierarchy_conf.hierarchy.bg.x_grid[end]),
u0_ie,ode_alg, reltol=reltol);
#1e-12 @btime 102.116 ms (2421616 allocations: 61.80 MiB)
#1e-6 @btime 258.186 ms (5807481 allocations: 147.26 MiB)
#1e-3 @btime 163.349 ms (3962663 allocations: 100.20 MiB)

#--------------------------------



Expand Down
1 change: 0 additions & 1 deletion src/perturbations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ struct ConformalHierarchy{T<:Real, H <: Hierarchy{T}, IT <: AbstractInterpolati
hierarchy::H
η2x::IT
end
#^Not sure this is the best way to wrap this...

function boltsolve(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6,abstol=1e-6) where T
xᵢ = first(hierarchy.bg.x_grid)
Expand Down

0 comments on commit 1e8f959

Please sign in to comment.