From e9a326edace4d20b48259ddf8d51e5086c7c4a6e Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Wed, 17 May 2023 16:40:38 -0700 Subject: [PATCH 01/10] take out neutrinos, no nn --- scripts/first_plin.jl | 2 +- src/Bolt.jl | 3 +- src/background.jl | 60 +-------- src/perturbations.jl | 287 ++++++------------------------------------ 4 files changed, 44 insertions(+), 308 deletions(-) diff --git a/scripts/first_plin.jl b/scripts/first_plin.jl index 6bf28a8..21fb465 100644 --- a/scripts/first_plin.jl +++ b/scripts/first_plin.jl @@ -1,4 +1,4 @@ -using Revise +# using Revise using Bolt using ForwardDiff using Plots diff --git a/src/Bolt.jl b/src/Bolt.jl index e2dd332..c164c21 100644 --- a/src/Bolt.jl +++ b/src/Bolt.jl @@ -60,8 +60,7 @@ abstract type AbstractCosmoParams{T} end A = 2.097e-9 # scalar amplitude, 1e-10*exp(3.043) n = 1.0 # scalar spectral index Y_p = 0.24 # primordial helium fraction - N_ν = 3.046 #effective number of relativisic species (PDG25 value) - Σm_ν = 0.06 #sum of neutrino masses (eV), Planck 15 default ΛCDM value + α_c = -3.0 # not sure if this would be a problem if it were an int? end include("util.jl") diff --git a/src/background.jl b/src/background.jl index bd56264..0713a70 100644 --- a/src/background.jl +++ b/src/background.jl @@ -7,59 +7,15 @@ H₀(par::AbstractCosmoParams) = par.h * km_s_Mpc_100 ρ_crit(par::AbstractCosmoParams) = (3 / 8π) * H₀(par)^2 / G_natural function Ω_Λ(par::AbstractCosmoParams) #Below can definitely be more streamlined, I am just making it work for now - Tγ = (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - νfac =(90 * ζ /(11 * π^4)) * (par.Ω_r * par.h^2 / Tγ) *((par.N_ν/3)^(3/4)) - #^the factor that goes into nr approx to neutrino energy density, plus equal sharing ΔN_eff factor for single massive neutrino - Ω_ν = par.Σm_ν*νfac/par.h^2 - return 1 - (par.Ω_r*(1+(2/3)*(7par.N_ν/8)*(4/11)^(4/3)) # dark energy density - + par.Ω_b + par.Ω_c - + Ω_ν - ) #assume massive nus are non-rel today -end - -#background FD phase space -function f0(q,par::AbstractCosmoParams) - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - gs = 2 #should be 2 for EACH neutrino family (mass eigenstate) - return gs / (2π)^3 / ( exp(q/Tν) +1) -end - -function dlnf0dlnq(q,par::AbstractCosmoParams) #this is actually only used in perts - Tν = (par.N_ν/3)^(1/4) * (4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - return -q / Tν /(1 + exp(-q/Tν)) -end - -#This is just copied from perturbations.jl for now - but take out Pressure - maybe later restore for FD tests? -function ρP_0(a,par::AbstractCosmoParams,quad_pts,quad_wts) - #Do q integrals to get the massive neutrino metric perturbations - #MB eqn (55) - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - #Not allowed to set Neff=0 o.w. breaks this #FIXME add an error message - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - #FIXME: avoid repeating code? and maybe put general integrals in utils? - m = par.Σm_ν - ϵx(x, am) = √(xq2q(x,logqmin,logqmax)^2 + (am)^2) - Iρ(x) = xq2q(x,logqmin,logqmax)^2 * ϵx(x, a*m) * f0(xq2q(x,logqmin,logqmax),par) / dxdq(xq2q(x,logqmin,logqmax),logqmin,logqmax) - IP(x) = xq2q(x,logqmin,logqmax)^2 * (xq2q(x,logqmin,logqmax)^2 /ϵx(x, a*m)) * f0(xq2q(x,logqmin,logqmax),par) / dxdq(xq2q(x,logqmin,logqmax),logqmin,logqmax) - xq,wq =quad_pts,quad_wts - ρ = 4π * a^(-4) * sum(Iρ.(xq).*wq) - P = 4π/3 * a^(-4) *sum(IP.(xq).*wq) - return ρ,P -end - -#neglect neutrinos, this is for ionization debugging purposes only -function oldH_a(a, par::AbstractCosmoParams) - return H₀(par) * √((par.Ω_c + par.Ω_b ) * a^(-3) - + par.Ω_r*(1+(2/3)*(7par.N_ν/8)*(4/11)^(4/3)) * a^(-4) - + Ω_Λ(par)) + return 1 - (par.Ω_r + par.Ω_b + par.Ω_c) end # Hubble parameter ȧ/a in Friedmann background function H_a(a, par::AbstractCosmoParams,quad_pts,quad_wts) ρ_ν,_ = ρP_0(a,par,quad_pts,quad_wts) #FIXME dropped pressure, need to decide if we want it for tests? - return H₀(par) * √((par.Ω_c + par.Ω_b ) * a^(-3) - + ρ_ν/ρ_crit(par) - + par.Ω_r* a^(-4)*(1+(2/3)*(7par.N_ν/8)*(4/11)^(4/3)) + return H₀(par) * √(par.Ω_c * a^par.α_c + + par.Ω_b * a^(-3) + + par.Ω_r* a^(-4) + Ω_Λ(par)) end # conformal time Hubble parameter, aH @@ -69,15 +25,7 @@ end H(x, par::AbstractCosmoParams,quad_pts,quad_wts) = H_a(x2a(x),par,quad_pts,quad_wts) ℋ(x, par::AbstractCosmoParams,quad_pts,quad_wts) = ℋ_a(x2a(x), par,quad_pts,quad_wts) -# conformal time -function η(x, par::AbstractCosmoParams,quad_pts,quad_wts) - logamin,logamax=-13.75,log10(x2a(x)) - Iη(y) = 1.0 / (xq2q(y,logamin,logamax) * ℋ_a(xq2q(y,logamin,logamax), par,quad_pts,quad_wts))/ dxdq(xq2q(y,logamin,logamax),logamin,logamax) - return sum(Iη.(quad_pts).*quad_wts) -end - # now build a Background with these functions - # a background is parametrized on the scalar type T, the interpolator type IT, # and a type for the grid GT abstract type AbstractBackground{T, IT<:AbstractInterpolation{T,1}, GT} end diff --git a/src/perturbations.jl b/src/perturbations.jl index 595ca1c..fbc2933 100644 --- a/src/perturbations.jl +++ b/src/perturbations.jl @@ -18,7 +18,7 @@ struct Hierarchy{T<:Real, PI<:PerturbationIntegrator, CP<:AbstractCosmoParams{T} end Hierarchy(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::AbstractBackground, - ih::AbstractIonizationHistory, k::Real, ℓᵧ=8, ℓ_ν=8, ℓ_mν=10, nq=15) = Hierarchy(integrator, par, bg, ih, k, ℓᵧ, ℓ_ν,ℓ_mν, nq) + ih::AbstractIonizationHistory, k::Real, ℓᵧ=8) = Hierarchy(integrator, par, bg, ih, k, ℓᵧ) @@ -32,165 +32,40 @@ function boltsolve(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6, abs return sol end -function rsa_perts!(u, hierarchy::Hierarchy{T},x) where T - #redundant code for what we need to compute RSA perts in place in u - k, ℓᵧ, par, bg, ih, nq = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih,hierarchy.nq - Ω_r, Ω_b, Ω_c, N_ν, m_ν, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, par.N_ν, par.Σm_ν, bg.H₀^2 #add N_ν≡N_eff - ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) - a = x2a(x) - Ω_ν = 7*(2/3)*N_ν/8 *(4/11)^(4/3) *Ω_r - csb² = ih.csb²(x) - ℓ_ν = hierarchy.ℓ_ν - Θ, Θᵖ, 𝒩, ℳ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) - # Θ′, Θᵖ′, 𝒩′, ℳ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 - - ρℳ, σℳ = ρ_σ(ℳ[0:nq-1], ℳ[2*nq:3*nq-1], bg, a, par) #monopole (energy density, 00 part),quadrupole (shear stress, ij part) - Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2]+ - Ω_ν * 𝒩[2] - + σℳ / bg.ρ_crit /4 - ) - Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( - Ω_c * a^(-1) * δ + Ω_b * a^(-1) * δ_b - + 4Ω_r * a^(-2) * Θ[0] - + 4Ω_ν * a^(-2) * 𝒩[0] - + a^(-2) * ρℳ / bg.ρ_crit - ) - - #fixed RSA - Θ[0] = Φ - ℋₓ/k *τₓ′ * v_b - Θ[1] = ℋₓ/k * ( -2Φ′ + τₓ′*( Φ - csb²*δ_b ) - + ℋₓ/k*( τₓ′′ - τₓ′ )*v_b ) - Θ[2] = 0 - #massless neutrinos - 𝒩[0] = Φ - 𝒩[1] = -2ℋₓ/k *Φ′ - 𝒩[2] = 0 - - u[1] = Θ[0] - u[2] = Θ[1] - u[3] = Θ[2] - - u[2(ℓᵧ+1)+1] = 𝒩[0] - u[2(ℓᵧ+1)+2] = 𝒩[1] - u[2(ℓᵧ+1)+3] = 𝒩[2] - - #zero the rest to avoid future confusion - for ℓ in 3:(ℓᵧ) - u[ℓ] = 0 - u[(ℓᵧ+1)+ℓ] = 0 - end - for ℓ in 3:(ℓ_ν) u[2(ℓᵧ+1)+ℓ] = 0 end - return nothing -end - -function boltsolve_rsa(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6, abstol=1e-6) where T - #call solve as usual first - perturb = boltsolve(hierarchy, reltol=reltol, abstol=abstol) - x_grid = hierarchy.bg.x_grid - pertlen = 2(hierarchy.ℓᵧ+1)+(hierarchy.ℓ_ν+1)+(hierarchy.ℓ_mν+1)*hierarchy.nq+5 - results=zeros(pertlen,length(x_grid)) - for i in 1:length(x_grid) results[:,i] = perturb(x_grid[i]) end - #replace the late-time perts with RSA approx (assuming we don't change rsa switch) - #this_rsa_switch = x_grid[argmin(abs.(hierarchy.k .* hierarchy.bg.η.(x_grid) .- 45))] - - xrsa_hor = findfirst(>(240), @. hierarchy.k * hierarchy.bg.η) - xrsa_od = findfirst(>(100), @. -hierarchy.ih.τ′*hierarchy.bg.ℋ/hierarchy.bg.η) - xrsa_hor = isnothing(xrsa_hor) ? length(x_grid) : xrsa_hor - xrsa_od = isnothing(xrsa_hor) ? length(x_grid) : xrsa_od - - this_rsa_switch = x_grid[max(xrsa_hor,xrsa_od)] - x_grid_rsa = x_grid[x_grid.>this_rsa_switch] - results_rsa = results[:,x_grid.>this_rsa_switch] - #(re)-compute the RSA perts so we can write them to the output vector - for i in 1:length(x_grid_rsa) - rsa_perts!(view(results_rsa,:,i),hierarchy,x_grid_rsa[i]) #to mutate need to use view... - end - results[:,x_grid.>this_rsa_switch] = results_rsa - sol = results - return sol -end # basic Newtonian gauge: establish the order of perturbative variables in the ODE solve function unpack(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T ℓᵧ = hierarchy.ℓᵧ - ℓ_ν = hierarchy.ℓ_ν - ℓ_mν = hierarchy.ℓ_mν #should be smaller than others - nq = hierarchy.nq Θ = OffsetVector(view(u, 1:(ℓᵧ+1)), 0:ℓᵧ) # indexed 0 through ℓᵧ Θᵖ = OffsetVector(view(u, (ℓᵧ+2):(2ℓᵧ+2)), 0:ℓᵧ) # indexed 0 through ℓᵧ - 𝒩 = OffsetVector(view(u, (2(ℓᵧ+1) + 1):(2(ℓᵧ+1)+ℓ_ν+1)) , 0:ℓ_ν) # indexed 0 through ℓ_ν - ℳ = OffsetVector(view(u, (2(ℓᵧ+1)+(ℓ_ν+1)+1):(2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq )) , 0:(ℓ_mν+1)*nq -1) # indexed 0 through ℓ_mν - Φ, δ, v, δ_b, v_b = view(u, ((2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq)+1 :(2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq)+5)) #getting a little messy... - return Θ, Θᵖ, 𝒩, ℳ, Φ, δ, v, δ_b, v_b -end - -function ρ_σ(ℳ0,ℳ2,bg,a,par::AbstractCosmoParams) #a mess - #Do q integrals to get the massive neutrino metric perturbations - #MB eqn (55) - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - #^Replace this with bg.ρ_crit? I think it is using an imported function ρ_crit - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - - #FIXME: avoid repeating code? and maybe put general integrals in utils? - m = par.Σm_ν - nq = length(ℳ0) #assume we got this right - ϵx(x, am) = √(xq2q(x,logqmin,logqmax)^2 + (am)^2) - Iρ(x) = xq2q(x,logqmin,logqmax)^2 * ϵx(x, a*m) * f0(xq2q(x,logqmin,logqmax),par) / dxdq(xq2q(x,logqmin,logqmax),logqmin,logqmax) - Iσ(x) = xq2q(x,logqmin,logqmax)^2 * (xq2q(x,logqmin,logqmax)^2 /ϵx(x, a*m)) * f0(xq2q(x,logqmin,logqmax),par) / dxdq(xq2q(x,logqmin,logqmax),logqmin,logqmax) - xq,wq = bg.quad_pts,bg.quad_wts - ρ = 4π*sum(Iρ.(xq).*ℳ0.*wq) - σ = 4π*sum(Iσ.(xq).*ℳ2.*wq) - # #a-dependence has been moved into Einstein eqns, as have consts in σ - return ρ,σ -end - -#need a separate function for θ (really(ρ̄+P̄)θ) for plin gauge change -function θ(ℳ1,bg,a,par::AbstractCosmoParams) #a mess - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *bg.ρ_crit *par.Ω_r)^(1/4) - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - m = par.Σm_ν - nq = length(ℳ1) #assume we got this right - Iθ(x) = xq2q(x,logqmin,logqmax)^3 * f0(xq2q(x,logqmin,logqmax),par) / dxdq(xq2q(x,logqmin,logqmax),logqmin,logqmax) - xq,wq = bg.quad_pts,bg.quad_wts - θ = 4π*sum(Iθ.(xq).*ℳ1.*wq) - #Note that this still needs to be multiplied with ka^-4 prefactor - return θ + Φ, δ, v, δ_b, v_b = view(u, (2(ℓᵧ+1)+1):(2(ℓᵧ+1)+5)) #getting a little messy... + return Θ, Θᵖ, Φ, δ, v, δ_b, v_b end # BasicNewtonian comes from Callin+06 and the Dodelson textbook (dispatches on hierarchy.integrator) function hierarchy!(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T # compute cosmological quantities at time x, and do some unpacking - k, ℓᵧ, par, bg, ih, nq = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih,hierarchy.nq - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - q_pts = xq2q.(bg.quad_pts,logqmin,logqmax) - Ω_r, Ω_b, Ω_c, N_ν, m_ν, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, par.N_ν, par.Σm_ν, bg.H₀^2 #add N_ν≡N_eff + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) a = x2a(x) R = 4Ω_r / (3Ω_b * a) - Ω_ν = 7*(2/3)*N_ν/8 *(4/11)^(4/3) *Ω_r csb² = ih.csb²(x) + α_c = par.α_c - ℓ_ν = hierarchy.ℓ_ν - ℓ_mν = hierarchy.ℓ_mν - Θ, Θᵖ, 𝒩, ℳ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) - Θ′, Θᵖ′, 𝒩′, ℳ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + Θ′, Θᵖ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 - #do the q integrals for massive neutrino perts (monopole and quadrupole) - ρℳ, σℳ = ρ_σ(ℳ[0:nq-1], ℳ[2*nq:3*nq-1], bg, a, par) #monopole (energy density, 00 part),quadrupole (shear stress, ij part) # metric perturbations (00 and ij FRW Einstein eqns) - Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2]+ - Ω_ν * 𝒩[2]#add rel quadrupole - + σℳ / bg.ρ_crit /4 + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] ) Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( - Ω_c * a^(-1) * δ + Ω_b * a^(-1) * δ_b + Ω_c * a^(2+α_c) * δ + + Ω_b * a^(-1) * δ_b + 4Ω_r * a^(-2) * Θ[0] - + 4Ω_ν * a^(-2) * 𝒩[0] #add rel monopole on this line - + a^(-2) * ρℳ / bg.ρ_crit ) # matter @@ -198,72 +73,27 @@ function hierarchy!(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T v′ = -v - k / ℋₓ * Ψ δ_b′ = k / ℋₓ * v_b - 3Φ′ v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + - # neutrinos (massive, MB 57) - for (i_q, q) in zip(Iterators.countfrom(0), q_pts) - ϵ = √(q^2 + (a*m_ν)^2) - df0 = dlnf0dlnq(q,par) - #need these factors of 4 on Φ, Ψ terms due to MB pert defn - ℳ′[0* nq+i_q] = - k / ℋₓ * q/ϵ * ℳ[1* nq+i_q] + Φ′ * df0 - ℳ′[1* nq+i_q] = k / (3ℋₓ) * ( q/ϵ * (ℳ[0* nq+i_q] - 2ℳ[2* nq+i_q]) - ϵ/q * Ψ * df0) - for ℓ in 2:(ℓ_mν-1) - ℳ′[ℓ* nq+i_q] = k / ℋₓ * q / ((2ℓ+1)*ϵ) * ( ℓ*ℳ[(ℓ-1)* nq+i_q] - (ℓ+1)*ℳ[(ℓ+1)* nq+i_q] ) - end - ℳ′[ℓ_mν* nq+i_q] = q / ϵ * k / ℋₓ * ℳ[(ℓ_mν-1)* nq+i_q] - (ℓ_mν+1)/(ℋₓ *ηₓ) *ℳ[(ℓ_mν)* nq+i_q] #MB (58) similar to rel case but w/ q/ϵ + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) end - # RSA equations (implementation of CLASS default switches) - rsa_on = (k*ηₓ > 240) && (-τₓ′*ℋₓ / ηₓ > 100) - if rsa_on - #photons - Θ[0] = Φ - ℋₓ/k *τₓ′ * v_b - Θ[1] = -2Φ′/k + (k^-2)*( τₓ′′ * v_b + τₓ′ * (ℋₓ*v_b - csb² *δ_b/k + k*Φ) ) - Θ[1] = ℋₓ/k * ( -2Φ′ + τₓ′*( Φ - csb²*δ_b ) - + ℋₓ/k*( τₓ′′ - τₓ′ )*v_b ) - Θ[2] = 0 - #massless neutrinos - 𝒩[0] = Φ - 𝒩[1] = -2ℋₓ/k *Φ′ - 𝒩[2] = 0 - - # manual zeroing to avoid saving garbage - 𝒩′[:] = zeros(ℓ_ν+1) - Θ′[:] = zeros(ℓᵧ+1) - Θᵖ′[:] = zeros(ℓᵧ+1) - - else - #do usual hierarchy - # relativistic neutrinos (massless) - 𝒩′[0] = -k / ℋₓ * 𝒩[1] - Φ′ - 𝒩′[1] = k/(3ℋₓ) * 𝒩[0] - 2*k/(3ℋₓ) *𝒩[2] + k/(3ℋₓ) *Ψ - for ℓ in 2:(ℓ_ν-1) - 𝒩′[ℓ] = k / ((2ℓ+1) * ℋₓ) * ( ℓ*𝒩[ℓ-1] - (ℓ+1)*𝒩[ℓ+1] ) - end - #truncation (same between MB and Callin06/Dodelson) - 𝒩′[ℓ_ν] = k / ℋₓ * 𝒩[ℓ_ν-1] - (ℓ_ν+1)/(ℋₓ *ηₓ) *𝒩[ℓ_ν] - - - # photons - Π = Θ[2] + Θᵖ[2] + Θᵖ[0] - Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ - Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) - for ℓ in 2:(ℓᵧ-1) - Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - - (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) - end - - # polarized photons - Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) - for ℓ in 1:(ℓᵧ-1) - Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - - (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) - end - - # photon boundary conditions: diffusion damping - Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] - Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] - + # polarized photons + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) end + + # photon boundary conditions: diffusion damping + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] #END RSA du[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+1:2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in @@ -272,20 +102,14 @@ end # BasicNewtonian Integrator (dispatches on hierarchy.integrator) function initial_conditions(xᵢ, hierarchy::Hierarchy{T, BasicNewtonian}) where T - k, ℓᵧ, par, bg, ih, nq = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih, hierarchy.nq - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - q_pts = xq2q.(bg.quad_pts,logqmin,logqmax) - ℓ_ν = hierarchy.ℓ_ν - ℓ_mν = hierarchy.ℓ_mν - u = zeros(T, 2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+5) + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + u = zeros(T, 2(ℓᵧ+1)+5) ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(xᵢ), bg.ℋ′(xᵢ), bg.η(xᵢ), ih.τ′(xᵢ), ih.τ′′(xᵢ) - Θ, Θᵖ, 𝒩, ℳ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) H₀²,aᵢ² = bg.H₀^2,exp(xᵢ)^2 aᵢ = sqrt(aᵢ²) #These get a 3/3 since massive neutrinos behave as massless at time of ICs - Ω_ν = 7*(3/3)*par.N_ν/8 *(4/11)^(4/3) *par.Ω_r - f_ν = 1/(1 + 1/(7*(3/3)*par.N_ν/8 *(4/11)^(4/3))) + f_ν = 1/(1 + 1/(7*(3/3)*3.046/8 *(4/11)^(4/3))) # we need to actually keep this # metric and matter perturbations ℛ = 1.0 # set curvature perturbation to 1 @@ -311,29 +135,7 @@ function initial_conditions(xᵢ, hierarchy::Hierarchy{T, BasicNewtonian}) where v = -3k*Θ[1] v_b = v - # neutrino hierarchy - # we need xᵢ to be before neutrinos decouple, as always - 𝒩[0] = Θ[0] - 𝒩[1] = Θ[1] - 𝒩[2] = - (k^2 *ηₓ^2)/15 * 1 / (1 + 2/5 *f_ν) * Φ / 2 #MB - for ℓ in 3:ℓ_ν - 𝒩[ℓ] = k/((2ℓ+1)ℋₓ) * 𝒩[ℓ-1] #standard truncation - end - - #massive neutrino hierarchy - #It is confusing to use Ψℓ bc Ψ is already the metric pert, so will use ℳ - for (i_q, q) in zip(Iterators.countfrom(0), q_pts) - ϵ = √(q^2 + (aᵢ*par.Σm_ν)^2) - df0 = dlnf0dlnq(q,par) - ℳ[0* nq+i_q] = -𝒩[0] *df0 - ℳ[1* nq+i_q] = -ϵ/q * 𝒩[1] *df0 - ℳ[2* nq+i_q] = -𝒩[2] *df0 #drop quadratic+ terms in (ma/q) as in MB - for ℓ in 3:ℓ_mν #same scheme for higher-ell as for relativistic - ℳ[ℓ* nq+i_q] = q / ϵ * k/((2ℓ+1)ℋₓ) * ℳ[(ℓ-1)*nq+i_q] #approximation equivalent to MB, but add q/ϵ - leaving as 0 makes no big difference - end - end - - u[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+1:(2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+5)] .= Φ, δ, v, δ_b, v_b # write u with our variables + u[2(ℓᵧ+1)+1:(2(ℓᵧ+1)+5)] .= Φ, δ, v, δ_b, v_b # write u with our variables return u end @@ -342,32 +144,19 @@ end # Bardeen potential Ψ and its derivative ψ′ for an integrator, or we saved them function source_function(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T # compute some quantities - k, ℓᵧ, par, bg, ih,nq = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih,hierarchy.nq + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih H₀² = bg.H₀^2 ℋₓ, ℋₓ′, ℋₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.ℋ′′(x) τₓ, τₓ′, τₓ′′ = ih.τ(x), ih.τ′(x), ih.τ′′(x) g̃ₓ, g̃ₓ′, g̃ₓ′′ = ih.g̃(x), ih.g̃′(x), ih.g̃′′(x) a = x2a(x) - ρ0ℳ = bg.ρ₀ℳ(x) #get current value of massive neutrino backround density from spline - Tν = (par.N_ν/3)^(1/4) *(4/11)^(1/3) * (15/ π^2 *ρ_crit(par) *par.Ω_r)^(1/4) - Ω_ν = 7*(2/3)*par.N_ν/8 *(4/11)^(4/3) *par.Ω_r - logqmin,logqmax=log10(Tν/30),log10(Tν*30) - q_pts = xq2q.(bg.quad_pts,logqmin,logqmax) - Θ, Θᵖ, 𝒩, ℳ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) - Θ′, Θᵖ′, 𝒩′, ℳ′, Φ′, δ′, v′, δ_b′, v_b′ = unpack(du, hierarchy) + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + Θ′, Θᵖ′, Φ′, δ′, v′, δ_b′, v_b′ = unpack(du, hierarchy) - # recalulate these since we didn't save them (Callin eqns 39-42) - #^Also have just copied from before, but should save these maybe? - _, σℳ = ρ_σ(ℳ[0:nq-1], ℳ[2*nq:3*nq-1], bg, a, par) #monopole (energy density, 00 part),quadrupole (shear stress, ij part) - _, σℳ′ = ρ_σ(ℳ′[0:nq-1], ℳ′[2*nq:3*nq-1], bg, a, par) - Ψ = -Φ - 12H₀² / k^2 / a^2 * (par.Ω_r * Θ[2] - + Ω_ν * 𝒩[2] #add rel quadrupole - + σℳ / bg.ρ_crit) #why am I doing this? - because H0 pulls out a factor of rho crit - just unit conversion + Ψ = -Φ - 12H₀² / k^2 / a^2 * (par.Ω_r * Θ[2]) #why am I doing this? - because H0 pulls out a factor of rho crit - just unit conversion #this introduces a factor of bg density I cancel using the integrated bg mnu density now - Ψ′ = -Φ′ - 12H₀² / k^2 / a^2 * (par.Ω_r * (Θ′[2] - 2 * Θ[2]) - + Ω_ν * (𝒩′[2] - 2 * 𝒩[2]) - + (σℳ′ - 2 * σℳ) / bg.ρ_crit /4 ) + Ψ′ = -Φ′ - 12H₀² / k^2 / a^2 * (par.Ω_r * (Θ′[2] - 2 * Θ[2])) Π = Θ[2] + Θᵖ[2] + Θᵖ[0] Π′ = Θ′[2] + Θᵖ′[2] + Θᵖ′[0] From e8705835a5aa7a413aa67c1632774a8d7b735499 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Thu, 25 May 2023 15:11:44 -0400 Subject: [PATCH 02/10] 525 --- src/Bolt.jl | 4 +- src/background.jl | 21 +++-- src/ionization/ionization.jl | 130 +++++++++++++++++++++++++++-- src/ionization/recfast.jl | 124 ++++++++++++++-------------- src/perturbations.jl | 155 ++++++++++++++++++++++++++++++++++- src/spectra.jl | 28 ++----- 6 files changed, 355 insertions(+), 107 deletions(-) diff --git a/src/Bolt.jl b/src/Bolt.jl index c164c21..b5d72b1 100644 --- a/src/Bolt.jl +++ b/src/Bolt.jl @@ -3,8 +3,8 @@ module Bolt export CosmoParams, AbstractCosmoParams export Background, AbstractBackground export IonizationHistory, AbstractIonizationHistory, IonizationIntegrator -export Peebles, PeeblesI -export ρ_σ,ρP_0,f0,dlnf0dlnq,θ,oldH_a #FIXME: quick hack to look at perts +export Peebles, PeeblesI,ihPeebles +export ρ_σ,ρP_0,f0,dlnf0dlnq,θ,oldH_a,H_a #FIXME: quick hack to look at perts export Hierarchy, boltsolve, BasicNewtonian,unpack,rsa_perts!,boltsolve_rsa export IE,initial_conditions,unpack,ie_unpack export source_grid, quadratic_k, cltt,log10_k,plin diff --git a/src/background.jl b/src/background.jl index 0713a70..3478faf 100644 --- a/src/background.jl +++ b/src/background.jl @@ -11,19 +11,25 @@ function Ω_Λ(par::AbstractCosmoParams) end # Hubble parameter ȧ/a in Friedmann background -function H_a(a, par::AbstractCosmoParams,quad_pts,quad_wts) - ρ_ν,_ = ρP_0(a,par,quad_pts,quad_wts) #FIXME dropped pressure, need to decide if we want it for tests? +function H_a(a, par::AbstractCosmoParams) return H₀(par) * √(par.Ω_c * a^par.α_c + par.Ω_b * a^(-3) + par.Ω_r* a^(-4) + Ω_Λ(par)) end # conformal time Hubble parameter, aH -ℋ_a(a, par::AbstractCosmoParams,quad_pts,quad_wts) = a * H_a(a, par,quad_pts,quad_wts) +ℋ_a(a, par::AbstractCosmoParams) = a * H_a(a, par) # functions in terms of x -H(x, par::AbstractCosmoParams,quad_pts,quad_wts) = H_a(x2a(x),par,quad_pts,quad_wts) -ℋ(x, par::AbstractCosmoParams,quad_pts,quad_wts) = ℋ_a(x2a(x), par,quad_pts,quad_wts) +H(x, par::AbstractCosmoParams) = H_a(x2a(x),par) +ℋ(x, par::AbstractCosmoParams) = ℋ_a(x2a(x), par) + +# conformal time +function η(x, par::AbstractCosmoParams,quad_pts,quad_wts) + logamin,logamax=-13.75,log10(x2a(x)) + Iη(y) = 1.0 / (xq2q(y,logamin,logamax) * ℋ_a(xq2q(y,logamin,logamax), par))/ dxdq(xq2q(y,logamin,logamax),logamin,logamax) + return sum(Iη.(quad_pts).*quad_wts) +end # now build a Background with these functions # a background is parametrized on the scalar type T, the interpolator type IT, @@ -46,13 +52,11 @@ struct Background{T, IT, GT} <: AbstractBackground{T, IT, GT} η::IT η′::IT η′′::IT - ρ₀ℳ::IT end function Background(par::AbstractCosmoParams{T}; x_grid=-20.0:0.01:0.0, nq=15) where T quad_pts, quad_wts = gausslegendre( nq ) - ρ₀ℳ_ = spline([ρP_0(x2a(x), par,quad_pts,quad_wts)[1] for x in x_grid], x_grid) - ℋ_ = spline([ℋ(x, par,quad_pts,quad_wts) for x in x_grid], x_grid) + ℋ_ = spline([ℋ(x, par) for x in x_grid], x_grid) η_ = spline([η(x, par,quad_pts,quad_wts) for x in x_grid], x_grid) return Background( T(H₀(par)), @@ -71,6 +75,5 @@ function Background(par::AbstractCosmoParams{T}; x_grid=-20.0:0.01:0.0, nq=15) w η_, spline_∂ₓ(η_, x_grid), spline_∂ₓ²(η_, x_grid), - ρ₀ℳ_, ) end diff --git a/src/ionization/ionization.jl b/src/ionization/ionization.jl index 2375fa5..03a23d5 100644 --- a/src/ionization/ionization.jl +++ b/src/ionization/ionization.jl @@ -5,9 +5,9 @@ const PeeblesT₀ = ustrip(natural(2.725u"K")) # CMB temperature [K] # TODO: make this a parameter of the ionization n_b(a, par) = par.Ω_b * ρ_crit(par) / (m_H * a^3) n_H(a, par) = n_b(a, par) *(1-par.Y_p) #Adding Helium! -saha_T_b(a, par) = PeeblesT₀ / a #j why does this take par? -saha_rhs(a, par) = (m_e * saha_T_b(a, par) / 2π)^(3/2) / n_H(a, par) * - exp(-ε₀_H / saha_T_b(a, par)) # rhs of Callin06 eq. 12 +saha_T_b(a) = PeeblesT₀ / a #j why does this take par? +saha_rhs(a, par) = (m_e * saha_T_b(a) / 2π)^(3/2) / n_H(a, par) * + exp(-ε₀_H / saha_T_b(a)) # rhs of Callin06 eq. 12 function saha_Xₑ(x, par::AbstractCosmoParams) rhs = saha_rhs(x2a(x), par) @@ -27,6 +27,17 @@ const m_H = ustrip(natural(float(ProtonMass))) const α = ustrip(natural(float(FineStructureConstant))) const σ_T = ustrip(natural(float(ThomsonCrossSection))) +const C_rf = 2.99792458e8 +const k_B_rf = 1.380658e-23 +const m_H_rf = 1.673575e-27 +const not4_rf = 3.9715e0 +const xinitial_RECFAST = z2x(10000.0) +const sigma = 6.6524616e-29 +const m_e_rf = 9.1093897e-31 +const zre_ini = 50.0 +const tol_rf = 1e-8 +# const Kelvin_natural_unit_conversion = # this is defined in recfast + # auxillary equations ϕ₂(T_b) = 0.448 * log(ε₀_H / T_b) α⁽²⁾(T_b) = (64π / √(27π)) * (α^2 / m_e^2) * √(ε₀_H / T_b) * ϕ₂(T_b) @@ -34,7 +45,7 @@ const σ_T = ustrip(natural(float(ThomsonCrossSection))) β⁽²⁾(T_b) = β(T_b) * exp(3ε₀_H / 4T_b) n₁ₛ(a, Xₑ, par) = (1 - Xₑ) * n_H(a, par) #Problem is here \/ since Lyα rate is given by redshifting out of line need H -Λ_α(a, Xₑ, par) = oldH_a(a, par) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) +Λ_α(a, Xₑ, par) = H_a(a, par) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) new_Λ_α(a, Xₑ, par, ℋ_function) = ℋ_function(log(a)) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) Cᵣ(a, Xₑ, T_b, par) = (Λ_2s_to_1s + Λ_α(a, Xₑ, par)) / ( Λ_2s_to_1s + Λ_α(a, Xₑ, par) + β⁽²⁾(T_b)) @@ -42,10 +53,10 @@ new_Cᵣ(a, Xₑ, T_b, par,ℋ_function) = (Λ_2s_to_1s + new_Λ_α(a, Xₑ, par Λ_2s_to_1s + new_Λ_α(a, Xₑ, par,ℋ_function) + β⁽²⁾(T_b)) # RHS of Callin06 eq. 13 -function peebles_Xₑ′(Xₑ, par, x) +function peebles_Xₑ′(Xₑ, par::CosmoParams{T}, x) where T a = exp(x) - T_b_a = BigFloat(saha_T_b(a, par)) # handle overflows by switching to bigfloat - return float(Cᵣ(a, Xₑ, T_b_a, par) / oldH_a(a, par) * ( + T_b_a = BigFloat(saha_T_b(a)) # handle overflows by switching to bigfloat + return T(Cᵣ(a, Xₑ, T_b_a, par) / H_a(a, par) * ( β(T_b_a) * (1 - Xₑ) - n_H(a, par) * α⁽²⁾(T_b_a) * Xₑ^2)) end @@ -129,7 +140,7 @@ function τ′(x, Xₑ_function, par, ℋ_function) end function oldτ′(x, Xₑ_function, par) a = x2a(x) - return -Xₑ_function(x) * n_H(a, par) * a * σ_T / (a*oldH_a(a,par)) + return -Xₑ_function(x) * n_H(a, par) * a * σ_T / (a*H_a(a,par)) end function g̃_function(τ_x_function, τ′_x_function) @@ -162,3 +173,106 @@ function customion(par, bg, Xₑ_function, Tmat_function, csb²_function) csb²_, ) end + +#----------------------- + +function reionization_Xe(𝕡::CosmoParams, Xe_func, z) + X_fin = 1 + 𝕡.Y_p / ( not4_rf*(1-𝕡.Y_p) ) #ionization frac today + zre,α,ΔH,zHe,ΔHe,fHe = 7.6711,1.5,0.5,3.5,0.5,X_fin-1 #reion params, TO REPLACE + x_orig = Xe_func(z2x(z)) + x_reio_H = (X_fin - x_orig) / 2 * ( + 1 + tanh(( (1+zre)^α - (1+z)^α ) / ( α*(1+zre)^(α-1) ) / ΔH)) + x_orig + x_reio_He = fHe / 2 * ( 1 + tanh( (zHe - z) / ΔHe) ) + x_reio = x_reio_H + x_reio_He + return x_reio +end + +function reionization_Tmat_ode(Tm,z) + x_reio = reionization_Xe(𝕡, Xe_func,z) + a = 1 / (1+z) + x_a = a2x(a) + Hz = bg.ℋ(x_a) / a / H0_natural_unit_conversion + Trad =Tnow_rf * (1+z) + CT_rf = (8/3)*(sigma/(m_e_rf*C_rf))*a + fHe = 𝕡.Y_p/(not4_rf*(1 - 𝕡.Y_p)) + dTm = CT_rf * Trad^4 * x_reio/(1 + x_reio + fHe) * + (Tm - Trad) / (Hz * (1 + z)) + 2 * Tm / (1 + z) + return dTm +end + +function tanh_reio_solve(Tmat0; zre_ini=50.0,zfinal=0.0) + reio_prob = ODEProblem(reionization_Tmat_ode, + Tmat0, + (zre_ini, zfinal)) + sol_reio_Tmat = solve(reio_prob, Tsit5(), reltol=tol_rf) + trh = TanhReionizationHistory(zre_ini, ion_hist, sol_reio_Tmat); + return trh +end + + +# struct Peebles_hist{T, AB<:AbstractBackground{T},CT<:AbstractCosmoParams{T}} <: IonizationIntegrator +# par::CT +# bg::AB +# Xe +# end + +function ihPeebles(par::AbstractCosmoParams{T}, bg::AbstractBackground{T};zfinal=0.0) where T + + x_grid = bg.x_grid + Xₑ_function = saha_peebles_recombination(par) + τ, τ′ = τ_functions(x_grid, Xₑ_function, par, bg.ℋ) + g̃ = Bolt.g̃_function(τ, τ′) + spline, spline_∂ₓ, spline_∂ₓ² = Bolt.spline, Bolt.spline_∂ₓ, Bolt.spline_∂ₓ² + Xₑ_ = spline(Xₑ_function.(x_grid), x_grid) + τ_ = spline(τ.(x_grid), x_grid) + g̃_ = spline(g̃.(x_grid), x_grid) + + Tnow_rf = (15/ π^2 *bg.ρ_crit * par.Ω_r)^(1/4) * Kelvin_natural_unit_conversion #last thing is natural to K + Trad_function = x -> Tnow_rf * (1 + x2z(x)) + + # trhist = tanh_reio_solve(rhist) + Tmat0=Trad_function(xinitial_RECFAST) #FIXME CHECK + + function reionization_Tmat_ode(Tm,p,z) + x_reio = reionization_Xe(par, Xₑ_,z) + Trad = Tnow_rf * (1 + z) + Hz = bg.ℋ(z2x(z)) * (1 + z) / H0_natural_unit_conversion + a=z2a(z) + CT_rf = (8/3)*(sigma/(m_e_rf*C_rf))*a + fHe = par.Y_p/(not4_rf*(1 - par.Y_p)) + return CT_rf * Trad^4 * x_reio/(1 + x_reio + fHe) * + (Tm - Trad) / (Hz * (1 + z)) + 2 * Tm / (1 + z) + end + zre_ini=50.0 + reio_prob = ODEProblem(reionization_Tmat_ode, + Tmat0, + (zre_ini, zfinal)) + sol_reio_Tmat = solve(reio_prob, Tsit5(), reltol=tol_rf) + # trh = TanhReionizationHistory(zre_ini, ion_hist, sol_reio_Tmat) + + Tmat_function = x -> (x < z2x(zre_ini)) ? + Trad_function(x) : sol_reio_Tmat(x2z(x)) + + Tmat_ = spline(Tmat_function.(x_grid), x_grid) + Yp = par.Y_p + mu_T_rf = not4_rf/(not4_rf-(not4_rf-1)*Yp) + csb²_pre = @.( C_rf^-2 * k_B_rf/m_H_rf * ( 1/mu_T_rf + (1-Yp)*Xₑ_(x_grid) ) ) #not the most readable... + #FIXME probably this is a bad way to do this... + csb²_ = spline(csb²_pre .* (Tmat_.(x_grid) .- 1/3 *spline_∂ₓ(Tmat_, x_grid).(x_grid)),x_grid) + # csb²_ = spline(csb²_function.(x_grid), x_grid) + + # println("typeof(Xₑ_) $(typeof(Xₑ_))") + # println("typeof(τ_) $(typeof(τ_))") + return IonizationHistory( + T(τ(0.)), + Xₑ_, + τ_, + spline_∂ₓ(τ_, x_grid), + spline_∂ₓ²(τ_, x_grid), + g̃_, + spline_∂ₓ(g̃_, x_grid), + spline_∂ₓ²(g̃_, x_grid), + Tmat_, + csb²_, + ) +end diff --git a/src/ionization/recfast.jl b/src/ionization/recfast.jl index 8c9dcd7..1a90b1c 100644 --- a/src/ionization/recfast.jl +++ b/src/ionization/recfast.jl @@ -21,80 +21,80 @@ end @with_kw struct RECFAST{T, AB<:AbstractBackground{T}} <: IonizationIntegrator bg::AB # a RECFAST has an associated background evolution - C::T = 2.99792458e8 # Fundamental constants in SI units - k_B::T = 1.380658e-23 - h_P::T = 6.6260755e-34 - m_e::T = 9.1093897e-31 - m_H::T = 1.673575e-27 # av. H atom + C::Float64 = 2.99792458e8 # Fundamental constants in SI units + k_B::Float64 = 1.380658e-23 + h_P::Float64 = 6.6260755e-34 + m_e::Float64 = 9.1093897e-31 + m_H::Float64 = 1.673575e-27 # av. H atom # note: neglecting deuterium, making an O(e-5) effect - not4::T = 3.9715e0 # mass He/H atom ("not4" pointed out by Gary Steigman) - sigma::T = 6.6524616e-29 - a::T = 7.565914e-16 - G::T = 6.6742e-11 # new value - - Lambda::T = 8.2245809e0 - Lambda_He::T = 51.3e0 # new value from Dalgarno - L_H_ion::T = 1.096787737e7 # level for H ion. (in m^-1) - L_H_alpha::T = 8.225916453e6 # averaged over 2 levels - L_He1_ion::T = 1.98310772e7 # from Drake (1993) - L_He2_ion::T = 4.389088863e7 # from JPhysChemRefData (1987) - L_He_2s::T = 1.66277434e7 # from Drake (1993) - L_He_2p::T = 1.71134891e7 # from Drake (1993) + not4::Float64 = 3.9715e0 # mass He/H atom ("not4" pointed out by Gary Steigman) + sigma::Float64 = 6.6524616e-29 + a::Float64 = 7.565914e-16 + G::Float64 = 6.6742e-11 # new value + + Lambda::Float64 = 8.2245809e0 + Lambda_He::Float64 = 51.3e0 # new value from Dalgarno + L_H_ion::Float64 = 1.096787737e7 # level for H ion. (in m^-1) + L_H_alpha::Float64 = 8.225916453e6 # averaged over 2 levels + L_He1_ion::Float64 = 1.98310772e7 # from Drake (1993) + L_He2_ion::Float64 = 4.389088863e7 # from JPhysChemRefData (1987) + L_He_2s::Float64 = 1.66277434e7 # from Drake (1993) + L_He_2p::Float64 = 1.71134891e7 # from Drake (1993) # C 2 photon rates and atomic levels in SI units - A2P_s::T = 1.798287e9 # Morton, Wu & Drake (2006) - A2P_t::T = 177.58e0 # Lach & Pachuski (2001) - L_He_2Pt::T = 1.690871466e7 # Drake & Morton (2007) - L_He_2St::T = 1.5985597526e7 # Drake & Morton (2007) - L_He2St_ion::T = 3.8454693845e6 # Drake & Morton (2007) - sigma_He_2Ps::T = 1.436289e-22 # Hummer & Storey (1998) - sigma_He_2Pt::T = 1.484872e-22 # Hummer & Storey (1998) + A2P_s::Float64 = 1.798287e9 # Morton, Wu & Drake (2006) + A2P_t::Float64 = 177.58e0 # Lach & Pachuski (2001) + L_He_2Pt::Float64 = 1.690871466e7 # Drake & Morton (2007) + L_He_2St::Float64 = 1.5985597526e7 # Drake & Morton (2007) + L_He2St_ion::Float64 = 3.8454693845e6 # Drake & Morton (2007) + sigma_He_2Ps::Float64 = 1.436289e-22 # Hummer & Storey (1998) + sigma_He_2Pt::Float64 = 1.484872e-22 # Hummer & Storey (1998) # C Atomic data for HeI - AGauss1::T = -0.14e0 # Amplitude of 1st Gaussian - AGauss2::T = 0.079e0 # Amplitude of 2nd Gaussian - zGauss1::T = 7.28e0 # ln(1+z) of 1st Gaussian - zGauss2::T = 6.73e0 # ln(1+z) of 2nd Gaussian - wGauss1::T = 0.18e0 # Width of 1st Gaussian - wGauss2::T = 0.33e0 # Width of 2nd Gaussian + AGauss1::Float64 = -0.14e0 # Amplitude of 1st Gaussian + AGauss2::Float64 = 0.079e0 # Amplitude of 2nd Gaussian + zGauss1::Float64 = 7.28e0 # ln(1+z) of 1st Gaussian + zGauss2::Float64 = 6.73e0 # ln(1+z) of 2nd Gaussian + wGauss1::Float64 = 0.18e0 # Width of 1st Gaussian + wGauss2::Float64 = 0.33e0 # Width of 2nd Gaussian # Gaussian fits for extra H physics (fit by Adam Moss, modified by Antony Lewis) # the Pequignot, Petitjean & Boisson fitting parameters for Hydrogen - a_PPB::T = 4.309 - b_PPB::T = -0.6166 - c_PPB::T = 0.6703 - d_PPB::T = 0.5300 + a_PPB::Float64 = 4.309 + b_PPB::Float64 = -0.6166 + c_PPB::Float64 = 0.6703 + d_PPB::Float64 = 0.5300 # the Verner and Ferland type fitting parameters for Helium # fixed to match those in the SSS papers, and now correct - a_VF::T = 10^(-16.744) - b_VF::T = 0.711 - T_0::T = 10^(0.477121) #!3K - T_1::T = 10^(5.114) + a_VF::Float64 = 10^(-16.744) + b_VF::Float64 = 0.711 + T_0::Float64 = 10^(0.477121) #!3K + T_1::Float64 = 10^(5.114) # fitting parameters for HeI triplets # (matches Hummer's table with <1% error for 10^2.8 < T/K < 10^4) - a_trip::T = 10^(-16.306) - b_trip::T = 0.761 + a_trip::Float64 = 10^(-16.306) + b_trip::Float64 = 0.761 # Set up some constants so they don't have to be calculated later - Lalpha::T = 1/L_H_alpha - Lalpha_He::T = 1/L_He_2p - DeltaB::T = h_P*C*(L_H_ion-L_H_alpha) - CDB::T = DeltaB/k_B - DeltaB_He::T = h_P*C*(L_He1_ion-L_He_2s) # 2s, not 2p - CDB_He::T = DeltaB_He/k_B - CB1::T = h_P*C*L_H_ion/k_B - CB1_He1::T = h_P*C*L_He1_ion/k_B # ionization for HeI - CB1_He2::T = h_P*C*L_He2_ion/k_B # ionization for HeII - CR::T = 2π * (m_e/h_P)*(k_B/h_P) - CK::T = Lalpha^3/(8π) - CK_He::T = Lalpha_He^3/(8π) - CL::T = C*h_P/(k_B*Lalpha) - CL_He::T = C*h_P/(k_B/L_He_2s) # comes from det.bal. of 2s-1s - CT::T = (8/3)*(sigma/(m_e*C))*a - Bfact::T = h_P*C*(L_He_2p-L_He_2s)/k_B + Lalpha::Float64 = 1/L_H_alpha + Lalpha_He::Float64 = 1/L_He_2p + DeltaB::Float64 = h_P*C*(L_H_ion-L_H_alpha) + CDB::Float64 = DeltaB/k_B + DeltaB_He::Float64 = h_P*C*(L_He1_ion-L_He_2s) # 2s, not 2p + CDB_He::Float64 = DeltaB_He/k_B + CB1::Float64 = h_P*C*L_H_ion/k_B + CB1_He1::Float64 = h_P*C*L_He1_ion/k_B # ionization for HeI + CB1_He2::Float64 = h_P*C*L_He2_ion/k_B # ionization for HeII + CR::Float64 = 2π * (m_e/h_P)*(k_B/h_P) + CK::Float64 = Lalpha^3/(8π) + CK_He::Float64 = Lalpha_He^3/(8π) + CL::Float64 = C*h_P/(k_B*Lalpha) + CL_He::Float64 = C*h_P/(k_B/L_He_2s) # comes from det.bal. of 2s-1s + CT::Float64 = (8/3)*(sigma/(m_e*C))*a + Bfact::Float64 = h_P*C*(L_He_2p-L_He_2s)/k_B # Matter departs from radiation when t(Th) > H_frac * t(H) - H_frac::T = 1e-3 # choose some safely small number + H_frac::Float64 = 1e-3 # choose some safely small number # switches Hswitch::Int64 = 1 @@ -115,9 +115,10 @@ end fHe::T = Yp/(not4*(1 - Yp)) # n_He_tot / n_H_tot Nnow::T = 3 * HO * HO * OmegaB / (8π * G * mu_H * m_H) # TODO: should replace during GREAT GENERALIZATION - fu::T = (Hswitch == 0) ? 1.14 : 1.125 - b_He::T = 0.86 # Set the He fudge factor - tol::T = 1e-8 + fu::Float64 = (Hswitch == 0) ? 1.14 : 1.125 + b_He::Float64 = 0.86 # Set the He fudge factor + tol::Float64 = 1e-8 + end # helper constructor which dispatches on the background @@ -501,7 +502,6 @@ end function reionization_Tmat_ode(Tm, rh::RECFASTHistory, z) 𝕣 = rh.𝕣 x_reio = reionization_Xe(rh, z) - a = 1 / (1+z) x_a = a2x(a) Hz = 𝕣.bg.ℋ(x_a) / a / H0_natural_unit_conversion diff --git a/src/perturbations.jl b/src/perturbations.jl index fbc2933..4dd48a5 100644 --- a/src/perturbations.jl +++ b/src/perturbations.jl @@ -12,15 +12,29 @@ struct Hierarchy{T<:Real, PI<:PerturbationIntegrator, CP<:AbstractCosmoParams{T} ih::IH k::Tk ℓᵧ::Int # Boltzmann hierarchy cutoff, i.e. Seljak & Zaldarriaga - ℓ_ν::Int - ℓ_mν::Int - nq::Int end Hierarchy(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::AbstractBackground, ih::AbstractIonizationHistory, k::Real, ℓᵧ=8) = Hierarchy(integrator, par, bg, ih, k, ℓᵧ) +struct Hierarchy_nn{T<:Real, PI<:PerturbationIntegrator, CP<:AbstractCosmoParams{T}, + BG<:AbstractBackground, IH<:AbstractIonizationHistory, Tk<:Real, + AT<:Array{T,1}} +integrator::PI +par::CP +bg::BG +ih::IH +k::Tk +p1::AT #the +p2::AT +ℓᵧ::Int # Boltzmann hierarchy cutoff, i.e. Seljak & Zaldarriaga +end + +Hierarchy_nn(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::AbstractBackground, +ih::AbstractIonizationHistory, k::Real, p1::AbstractArray, p2::AbstractArray,ℓᵧ=8) = Hierarchy(integrator, par, bg, ih, k, p1,p2,ℓᵧ) + + function boltsolve(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6, abstol=1e-6) where T xᵢ = first(hierarchy.bg.x_grid) @@ -32,6 +46,17 @@ function boltsolve(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6, abs return sol end +function boltsolve_nn(hierarchy::Hierarchy_nn{T}, ode_alg=KenCarp4(); reltol=1e-6, abstol=1e-6) where T + xᵢ = first(hierarchy.bg.x_grid) + u₀ = initial_conditions_nn(xᵢ, hierarchy) + prob = ODEProblem{true}(hierarchy_nn!, u₀, (xᵢ , zero(T)), hierarchy) + sol = solve(prob, ode_alg, reltol=reltol, abstol=abstol, + saveat=hierarchy.bg.x_grid, dense=false, + ) + return sol +end + + # basic Newtonian gauge: establish the order of perturbative variables in the ODE solve function unpack(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T @@ -42,6 +67,14 @@ function unpack(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T return Θ, Θᵖ, Φ, δ, v, δ_b, v_b end +function unpack_nn(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T + ℓᵧ = hierarchy.ℓᵧ + Θ = OffsetVector(view(u, 1:(ℓᵧ+1)), 0:ℓᵧ) # indexed 0 through ℓᵧ + Θᵖ = OffsetVector(view(u, (ℓᵧ+2):(2ℓᵧ+2)), 0:ℓᵧ) # indexed 0 through ℓᵧ + Φ, δ, σ, δ_b, v_b = view(u, (2(ℓᵧ+1)+1):(2(ℓᵧ+1)+5)) #getting a little messy... + return Θ, Θᵖ, Φ, δ, σ, δ_b, v_b +end + # BasicNewtonian comes from Callin+06 and the Dodelson textbook (dispatches on hierarchy.integrator) function hierarchy!(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T # compute cosmological quantities at time x, and do some unpacking @@ -96,10 +129,77 @@ function hierarchy!(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] #END RSA - du[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+1:2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*nq+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + return nothing +end + +function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + + α_c = par.α_c + + # Θ, Θᵖ, Φ, δ_c, σ_c,δ_b, v_b = unpack_nn(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack_nn(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + + Θ′, Θᵖ′, _, _, _, _, _ = unpack_nn(du, hierarchy) # will be sweetened by .. syntax in 1.6 + + + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + # matter + θ₀,θ₂ = hierarchy.p1,hierarchy.p2 + NN₀,NN₂ = get_nn(θ₀,x,k),get_nn(θ₂,x,k) #get nn objects + δ′ = NN₀ #k / ℋₓ * v - 3Φ′ + # v′ = -v - k / ℋₓ * Ψ + v′ = NN₂ + # σ′ = NN₂ + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + + + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + + # polarized photons + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + + # photon boundary conditions: diffusion damping + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + #END RSA + + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + # du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, σ′, δ_b′, v_b′ # put non-photon perturbations back in return nothing end + # BasicNewtonian Integrator (dispatches on hierarchy.integrator) function initial_conditions(xᵢ, hierarchy::Hierarchy{T, BasicNewtonian}) where T k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih @@ -139,6 +239,53 @@ function initial_conditions(xᵢ, hierarchy::Hierarchy{T, BasicNewtonian}) where return u end + +function initial_conditions_nn(xᵢ, hierarchy::Hierarchy_nn{T, BasicNewtonian}) where T + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + u = zeros(T, 2(ℓᵧ+1)+5) + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(xᵢ), bg.ℋ′(xᵢ), bg.η(xᵢ), ih.τ′(xᵢ), ih.τ′′(xᵢ) + # Θ, Θᵖ, Φ, δ, σ, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + H₀²,aᵢ² = bg.H₀^2,exp(xᵢ)^2 + aᵢ = sqrt(aᵢ²) + #These get a 3/3 since massive neutrinos behave as massless at time of ICs + f_ν = 1/(1 + 1/(7*(3/3)*3.046/8 *(4/11)^(4/3))) # we need to actually keep this + α_c = par.α_c + + # metric and matter perturbations + ℛ = 1.0 # set curvature perturbation to 1 + Φ = (4f_ν + 10) / (4f_ν + 15) * ℛ # for a mode outside the horizon in radiation era + #choosing Φ=1 forces the following value for C, the rest of the ICs follow + C = -( (15 + 4f_ν)/(20 + 8f_ν) ) * Φ + + #trailing (redundant) factors are for converting from MB to Dodelson convention for clarity + Θ[0] = -40C/(15 + 4f_ν) / 4 + Θ[1] = 10C/(15 + 4f_ν) * (k^2 * ηₓ) / (3*k) + Θ[2] = -8k / (15ℋₓ * τₓ′) * Θ[1] + Θᵖ[0] = (5/4) * Θ[2] + Θᵖ[1] = -k / (4ℋₓ * τₓ′) * Θ[2] + Θᵖ[2] = (1/4) * Θ[2] + for ℓ in 3:ℓᵧ + Θ[ℓ] = -ℓ/(2ℓ+1) * k/(ℋₓ * τₓ′) * Θ[ℓ-1] + Θᵖ[ℓ] = -ℓ/(2ℓ+1) * k/(ℋₓ * τₓ′) * Θᵖ[ℓ-1] + end + + + # δ = 3/4 *(4Θ[0]) #the 4 converts δγ_MB -> Dodelson convention + δ = -α_c*Θ[0] # this is general enough to allow this to be any species + δ_b = δ + #we have that Θc = Θb = Θγ = Θν, but need to convert Θ = - k v (i absorbed in v) + # v = -3k*Θ[1] + v = α_c*k*Θ[1] + v_b = -3k*Θ[1] + # σ = 0.0 # this is an actual physical assumption - that DM has no anisotropic stress in Read + #^This is strong but oh well, we can revisit making it a free parameter later + + # u[2(ℓᵧ+1)+1:(2(ℓᵧ+1)+5)] .= Φ, δ, σ, δ_b, v_b # write u with our variables + u[2(ℓᵧ+1)+1:(2(ℓᵧ+1)+5)] .= Φ, δ, v, δ_b, v_b # write u with our variables + return u +end + #FIXME this is pretty old code that hasn't been tested in a while! # TODO: this could be extended to any Newtonian gauge integrator if we specify the # Bardeen potential Ψ and its derivative ψ′ for an integrator, or we saved them diff --git a/src/spectra.jl b/src/spectra.jl index c2f2438..1e3ed1b 100644 --- a/src/spectra.jl +++ b/src/spectra.jl @@ -83,36 +83,20 @@ function cltt(ℓ⃗, par::AbstractCosmoParams, bg, ih, sf) end -function plin(k, 𝕡::AbstractCosmoParams{T},bg,ih, - n_q=15,ℓᵧ=50,ℓ_ν=50,ℓ_mν=20,x=0,reltol=1e-5) where T +function plin(k, 𝕡::AbstractCosmoParams{T},bg,ih,ℓᵧ=15,x=0,reltol=1e-5) where T #copy code abvoe - hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ,ℓ_ν,ℓ_mν,n_q) #shoddy quality test values + hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ) #shoddy quality test values perturb = boltsolve(hierarchy; reltol=reltol) results = perturb(x) - ℳρ,_ = ρ_σ(results[2(ℓᵧ+1)+(ℓ_ν+1)+1:2(ℓᵧ+1)+(ℓ_ν+1)+n_q], - results[2(ℓᵧ+1)+(ℓ_ν+1)+2*n_q+1:2(ℓᵧ+1)+(ℓ_ν+1)+3*n_q], - bg,exp(x),𝕡)./ bg.ρ₀ℳ(x) - #Below assumes negligible neutrino pressure for the normalization (fine at z=0) - ℳθ = k*θ(results[2(ℓᵧ+1)+(ℓ_ν+1)+n_q+1:2(ℓᵧ+1)+(ℓ_ν+1)+2n_q], - bg,exp(x),𝕡)./ bg.ρ₀ℳ(x) #Also using the fact that a=1 at z=0 - δcN,δbN = results[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*n_q+2],results[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*n_q+4] - vcN,vbN = results[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*n_q+3],results[2(ℓᵧ+1)+(ℓ_ν+1)+(ℓ_mν+1)*n_q+5] - ℳρN,ℳθN = ℳρ,ℳθ - vmνN = -ℳθN / k + δcN,δbN = results[2(ℓᵧ+1)+2,:],results[2(ℓᵧ+1)+4,:] + vcN,vbN = results[2(ℓᵧ+1)+3,:],results[2(ℓᵧ+1)+5,:] #omegas to get weighted sum for total matter in background - Tγ = (15/ π^2 *bg.ρ_crit *𝕡.Ω_r)^(1/4) - ζ = 1.2020569 - νfac = (90 * ζ /(11 * π^4)) * (𝕡.Ω_r * 𝕡.h^2 / Tγ) *((𝕡.N_ν/3)^(3/4)) - #^the factor that goes into nr approx to neutrino energy density, plus equal sharing ΔN_eff factor for single massive neutrino - Ω_ν = 𝕡.Σm_ν*νfac/𝕡.h^2 - Ωm = 𝕡.Ω_c+𝕡.Ω_b+Ω_ν + Ωm = 𝕡.Ω_c+𝕡.Ω_b #construct gauge-invariant versions of density perturbations δc = δcN - 3bg.ℋ(x)*vcN /k δb = δbN - 3bg.ℋ(x)*vbN /k - #assume neutrinos fully non-relativistic and can be described by fluid (ok at z=0) - δmν = ℳρN - 3bg.ℋ(x)*vmνN /k - δm = (𝕡.Ω_c*δc + 𝕡.Ω_b*δb + Ω_ν*δmν) / Ωm + δm = (𝕡.Ω_c*δc + 𝕡.Ω_b*δb) / Ωm As=𝕡.A Pprim = As*(k/0.05)^(𝕡.n-1) #pivot scale from Planck (in Mpc^-1) PL= (2π^2 / k^3)*δm^2 *Pprim From 4c5ca1f15c49f78c22117f985b902bc7610a621a Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Thu, 15 Jun 2023 14:26:13 -0700 Subject: [PATCH 03/10] ude delta 10% rel noise --- Project.toml | 28 ++++ scripts/ude_fwddiff_deltac.jl | 287 ++++++++++++++++++++++++++++++++++ src/Bolt.jl | 6 +- src/ionization/ionization.jl | 194 ++++++++--------------- src/perturbations.jl | 187 ++++++++++++++++++++-- 5 files changed, 557 insertions(+), 145 deletions(-) create mode 100644 scripts/ude_fwddiff_deltac.jl diff --git a/Project.toml b/Project.toml index 38dbc7d..d583527 100644 --- a/Project.toml +++ b/Project.toml @@ -4,34 +4,62 @@ authors = ["Zack Li", "Jamie Sullivan"] version = "0.1.0" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypergeometricFunctions = "34004b35-14d8-5ef3-9330-4cdb6864b03a" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" NaturallyUnitful = "872cf16e-200e-11e9-2cdf-8bb39cfbec41" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" NumericalIntegration = "e7bfaba1-d571-5449-8927-abc22e82249b" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f" UnitfulCosmo = "961331e1-62bb-46d9-b9e3-f058129f1391" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/scripts/ude_fwddiff_deltac.jl b/scripts/ude_fwddiff_deltac.jl new file mode 100644 index 0000000..eb2541b --- /dev/null +++ b/scripts/ude_fwddiff_deltac.jl @@ -0,0 +1,287 @@ +using OrdinaryDiffEq +using Random +using LinearAlgebra, Statistics,Lux +rng = Xoshiro(123) +using Bolt +using Plots +using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays +using AbstractDifferentiation +import AbstractDifferentiation as AD, ForwardDiff + +# setup ks won't use all of them here... +L=2f3 +lkmi,lkmax,nk = log10(2.0f0*π/L),log10(0.2f0),8 +kk = 10.0f0.^(collect(lkmi:(lkmax-lkmi)/(nk-1):lkmax)) +ℓᵧ=15 +pertlen=2(ℓᵧ+1)+5 + +# define network +U = Lux.Chain(Lux.Dense(pertlen+1, 8, tanh), #input is u,t + Lux.Dense(8, 8,tanh), + Lux.Dense(8, 2)) +p, st = Lux.setup(rng, U) + +# copy the hierarchy function os it works for nn - for this to work you need the Hierarchy_nn struct and unpack_nn in perturbations.jl +function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + + # the new free cdm index (not used here) + α_c = par.α_c + # get the nn params + p_nn = hierarchy.p + + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack_nn(u, hierarchy) + Θ′, Θᵖ′, _, _, _, _, _ = unpack_nn(du, hierarchy) + + # Here I am throwing away the neutrinos entriely, which is probably bad + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c #ignore this + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + # matter + nnin = hcat([u...,x]) + û = U(nnin,p_nn,st)[1] + δ′ = û[1] + v′ = û[2] + # here we implicitly assume σ_c = 0 + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * Bolt.δ_kron(ℓ, 2) / 10) + end + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ + return nothing +end + +# use only the longest k mode +function hierarchy_nnu!(du, u, p, x) + hierarchy = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],p,15); + hierarchy_nn!(du, u, hierarchy, x) +end + +#log loss +function loss(θ) + X̂ = predict(θ) + log(mean(abs2, Ytrain_ode .- X̂) ) +end + + +# adtype = Optimization.AutoZygote() +adtype = Optimization.AutoForwardDiff() +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) +optprob = Optimization.OptimizationProblem(optf, ComponentVector(p)) + + +# some setup +tspan = (-20.0f0, 0.0f0) +𝕡test = CosmoParams(Ω_c=0.3,α_c=-3.0); +bgtest = Background(𝕡test; x_grid=-20.0f0:1f-1:0.0f0); +ihtest = Bolt.get_saha_ih(𝕡test, bgtest); +hierarchytest = Hierarchy(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],15); +hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], ComponentArray(p),15); +u0 = Bolt.initial_conditions_nn(tspan[1],hierarchytestnn) + +# problem for truth and one we will remake +prob_trueode = ODEProblem(Bolt.hierarchy!, u0, tspan,hierarchytest) +prob_nn = ODEProblem(hierarchy_nnu!, u0, (bgtest.x_grid[1],bgtest.x_grid[end]), ComponentArray{Float64}(p)) + +# Generate some noisy data (at reduced solver tolerance) +ode_data = Array(solve(prob_trueode, KenCarp4(), saveat = bgtest.x_grid, + abstol = 1e-3, reltol = 1e-3)) +δ_true,v_true = ode_data[end-3,:],ode_data[end-2,:] +σfakeode = 0.1 +noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode +noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode +Ytrain_ode = hcat([δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v]...) + +# NB I dropped the "Float64" type argument to the ComponentArray, maybe we should put it back? +function predict(θ, T = bgtest.x_grid) + _prob = remake(prob_nn, u0 = u0, tspan = (T[1], T[end]), p = θ) + res = Array(solve(_prob, KenCarp4(), saveat = T, + abstol = 1e-3, reltol = 1e-3)) + return hcat(res[end-3,:],res[end-2,:]) +end + +# test the prediction gradient +ab = AD.ForwardDiffBackend() +AD.jacobian(ab,predict,ComponentArray(p)) + + +# Training +losses = []; +callback = function (p, l) + push!(losses, l) + # if length(losses) % 50 == 0 + println("Current loss after $(length(losses)) iterations: $(losses[end])") + # end + return false +end +niter=2#80 +res1 = Optimization.solve(optprob, ADAM(1.0), callback = callback, maxiters = niter) +# this is pretty slow, on the order of half an hour, but it is at least running! +# Now idk what is wrong with reverse mode...but getting errors about typing... + +# Get the result +test_predict_o1 = predict(res1.u) + + +# Plots of the learned perturbations +Plots.scatter(bgtest.x_grid,Ytrain_ode[:,1],label="data") +Plots.plot!(bgtest.x_grid,δ_true,label="truth",yscale=:log10,lw=2.5) +Plots.plot!(bgtest.x_grid,test_predict_o1[:,1],label="opt-v1",lw=2.5,ls=:dash) +Plots.title!(raw"$\delta_{c}$") +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$\delta_{c}(a)$") +savefig("../plots/deltac_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") + +Plots.scatter(bgtest.x_grid,Ytrain_ode[:,2],label="data",yscale=:log10,legend=:bottomright) +Plots.plot!(bgtest.x_grid,v_true,label="truth") +Plots.plot!(bgtest.x_grid,test_predict_o1[:,2],label="opt-v1") +Plots.title!(raw"$v_{c}$") +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$v_{c}(a)$") +savefig("../plots/vc_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") + + +function get_Φ′_Ψ(u,hierarchy::Hierarchy{T},x) where T + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + α_c = par.α_c + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + return Φ′,Ψ +end + +# The reconstructed function +Φ′_true,Ψ_true = zeros(length(hierarchytest.bg.x_grid)),zeros(length(hierarchytest.bg.x_grid)) +nn_δ′,nn_v′ = zeros(length(hierarchytest.bg.x_grid)),zeros(length(hierarchytest.bg.x_grid)) +for j in 1:length(hierarchytest.bg.x_grid) + Φ′_true[j],Ψ_true[j] = get_Φ′_Ψ(ode_data[:,j],hierarchytest,hierarchytest.bg.x_grid[j]) + nnin = hcat([ode_data[:,j]...,hierarchytest.bg.x_grid[j]]) + nn_u′ = U(nnin,res1.u,st)[1] + nn_δ′[j],nn_v′[j] = nn_u′[1], nn_u′[2] +end + + +true_δ′ = hierarchytestnn.k ./ hierarchytestnn.bg.ℋ .* v_true .- 3Φ′_true +true_v′ = -v_true .- hierarchytestnn.k ./ hierarchytestnn.bg.ℋ .* Ψ_true + +Plots.plot(hierarchytestnn.bg.x_grid, true_δ′,label="truth",lw=2.5) +Plots.plot!(hierarchytestnn.bg.x_grid,nn_δ′,label="nn-v1",lw=2.5) +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$\delta'(a)$") +Plots.title!(raw"recon $v_{c}$") +savefig("../plots/deltacprime_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") + +Plots.plot(hierarchytestnn.bg.x_grid,true_v′,label="truth",lw=2.5) +Plots.plot!(hierarchytestnn.bg.x_grid,nn_v′,label="nn-v1",lw=2.5) +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$v'(a)$") +Plots.title!(raw"recon $v_{c}$") +savefig("../plots/vcprime_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") + +# -------------------------------- +# Old code testing AD backends + +# # try swapping out non-Enzyme AD backends as Marius suggested: +# using Zygote +# Zygote.gradient(loss,p) +# Zygote.gradient(loss, ComponentArray{Float64}(p) ) + + +# using Enzyme +# autodiff(Reverse, loss, Active, Active(p)) +# autodiff(Reverse, loss, Active, Active( ComponentArray{Float64}(p) )) +# autodiff(Forward, loss, Duplicated, Const(p)) + + +# ab = AD.ZygoteBackend() +# f(x) = log(sum(exp, x)) +# AD.gradient(ab, f, rand(10)) +# # Zygote +# AD.gradient(ab,loss,ComponentArray(p)) + + +# # RevserseDiff +# import ReverseDiff +# ab = AD.ReverseDiffBackend() +# AD.gradient(ab, f, rand(10)) +# # AD.gradient(ab,loss,p) +# AD.gradient(ab,loss,ComponentArray(p)) + +# # Tracker +# import Tracker +# ab = AD.TrackerBackend() +# AD.gradient(ab, f, rand(10)) +# AD.gradient(ab,loss,p) + +# # ForwardDiff +# import ForwardDiff +# ab = AD.ForwardDiffBackend() +# AD.gradient(ab, f, rand(10)) +# AD.gradient(ab,loss,p) + +# loss(ComponentArray(p)) +# AD.gradient(ab,loss,ComponentArray(p)) +# #this at least should work? why not? + +# #ok so why is the gradient zero? Is predictor gradient also zero? +# AD.jacobian(ab,predict,ComponentArray(p)) +# # jacobian, but yes. Why is it zero? + +# AD.jacobian(ab,predict,ComponentArray(p)) + +# typeof(ComponentArray(p)) + +# # Try fwd diff optimizing... + + + +# # FiniteDifferences +# import FiniteDifferences +# ab = AD.FiniteDifferencesBackend() +# AD.gradient(ab, f, rand(10)) +# AD.gradient(ab,loss,p) + +# # check that forwarddiff works because it did before... +# ForwardDiff.jacobian(loss,p) +# typeof(p) + diff --git a/src/Bolt.jl b/src/Bolt.jl index b5d72b1..269d85c 100644 --- a/src/Bolt.jl +++ b/src/Bolt.jl @@ -5,10 +5,11 @@ export Background, AbstractBackground export IonizationHistory, AbstractIonizationHistory, IonizationIntegrator export Peebles, PeeblesI,ihPeebles export ρ_σ,ρP_0,f0,dlnf0dlnq,θ,oldH_a,H_a #FIXME: quick hack to look at perts -export Hierarchy, boltsolve, BasicNewtonian,unpack,rsa_perts!,boltsolve_rsa +export Hierarchy, unpack_nn, Hierarchy_nn,initial_conditions_nn,hierarchy_nn!,boltsolve,boltsolve_nn, BasicNewtonian,unpack,rsa_perts!,boltsolve_rsa +export Hierarchy_spl,hierarchy_spl!,unpack_spl,boltsolve_spl,Hierarchy_nn export IE,initial_conditions,unpack,ie_unpack export source_grid, quadratic_k, cltt,log10_k,plin -export z2a, a2z, x2a, a2x, z2x, x2z, to_ui, from_ui, dxdq +export z2a, a2z, x2a, a2x, z2x, x2z, to_ui, from_ui, dxdq,δ_kron using Parameters using Unitful, UnitfulAstro @@ -25,6 +26,7 @@ using StaticArrays using DoubleFloats using MuladdMacro using LinearAlgebra +using SimpleChains,Random using FFTW diff --git a/src/ionization/ionization.jl b/src/ionization/ionization.jl index 03a23d5..0e798d3 100644 --- a/src/ionization/ionization.jl +++ b/src/ionization/ionization.jl @@ -5,16 +5,66 @@ const PeeblesT₀ = ustrip(natural(2.725u"K")) # CMB temperature [K] # TODO: make this a parameter of the ionization n_b(a, par) = par.Ω_b * ρ_crit(par) / (m_H * a^3) n_H(a, par) = n_b(a, par) *(1-par.Y_p) #Adding Helium! -saha_T_b(a) = PeeblesT₀ / a #j why does this take par? -saha_rhs(a, par) = (m_e * saha_T_b(a) / 2π)^(3/2) / n_H(a, par) * - exp(-ε₀_H / saha_T_b(a)) # rhs of Callin06 eq. 12 - -function saha_Xₑ(x, par::AbstractCosmoParams) - rhs = saha_rhs(x2a(x), par) - return (√(rhs^2 + 4rhs) - rhs) / 2 # solve Xₑ² / (1-Xₑ) = RHS, it's a polynomial +saha_T_b(a, par) = PeeblesT₀ / a #j why does this take par? +saha_rhs(a, par) = (m_e * saha_T_b(a, par) / 2π)^(3/2) / n_H(a, par) * + exp(-ε₀_H / saha_T_b(a, par)) # rhs of Callin06 eq. 12 + +function saha_Xₑ(x, par::AbstractCosmoParams, bg) + z = x2z(x) + CR = 1.7998756579640975e14 + CB1 = 157802.38230335814 + G = 6.6742e-11 + m_H = 1.673575e-27 + Yp = 0.24 + mu_H = 1 / (1 - Yp) + Tnow = (15/ π^2 *bg.ρ_crit * par.Ω_r)^(1/4) * Kelvin_natural_unit_conversion + HO = bg.H₀ / H0_natural_unit_conversion + OmegaB=par.Ω_b + Nnow = 3 * HO * HO * OmegaB / (8π * G * mu_H * m_H) + sqrtrhs = exp(1.5 * log(CR*Tnow/(1+z))/2 - CB1/(Tnow*(1+z))/2) / sqrt(Nnow) + return 2sqrtrhs / (√(sqrtrhs^2 + 4) + sqrtrhs) # this form is more stable end saha_Xₑ(par) = (x -> saha_Xₑ(x, par)) +function get_saha_ih(par::AbstractCosmoParams{T}, bg::AbstractBackground{T}) where T + Xₑ_function(x) = saha_Xₑ(x, par, bg) + OmegaG = par.Ω_r + Tnow = (15/ π^2 *bg.ρ_crit * OmegaG)^(1/4) * Kelvin_natural_unit_conversion + Tmat_function(x) = Tnow / exp(x) + C = 2.99792458e8 + k_B = 1.380658e-23 + m_H = 1.673575e-27 + not4 = 3.9715e0 + Yp = 0.24 + mu_T = not4/(not4-(not4-1)*Yp) + csb²_pre(x) = C^-2 * k_B/m_H * ( 1/mu_T + (1-Yp)*Xₑ_function(x) ) + csb²_function(x) = csb²_pre(x) * (Tmat_function(x) - 1/3 * (-Tmat_function(x))) + + ℋ_function = bg.ℋ + x_grid = bg.x_grid + τ, τ′ = τ_functions(x_grid, Xₑ_function, par, ℋ_function) + g̃ = g̃_function(τ, τ′) + + Xₑ_ = spline(Xₑ_function.(x_grid), x_grid) + τ_ = spline(τ.(x_grid), x_grid) + g̃_ = spline(g̃.(x_grid), x_grid) + Tmat_ = spline(Tmat_function.(x_grid), x_grid) + csb²_ = spline(csb²_function.(x_grid), x_grid) + + return IonizationHistory( + T(τ(0.)), + Xₑ_, + τ_, + spline_∂ₓ(τ_, x_grid), + spline_∂ₓ²(τ_, x_grid), + g̃_, + spline_∂ₓ(g̃_, x_grid), + spline_∂ₓ²(g̃_, x_grid), + Tmat_, + csb²_, + ) +end + ## Peebles Equation ## Use this for Xₑ < 0.99, i.e. z < 1587.4 @@ -27,17 +77,6 @@ const m_H = ustrip(natural(float(ProtonMass))) const α = ustrip(natural(float(FineStructureConstant))) const σ_T = ustrip(natural(float(ThomsonCrossSection))) -const C_rf = 2.99792458e8 -const k_B_rf = 1.380658e-23 -const m_H_rf = 1.673575e-27 -const not4_rf = 3.9715e0 -const xinitial_RECFAST = z2x(10000.0) -const sigma = 6.6524616e-29 -const m_e_rf = 9.1093897e-31 -const zre_ini = 50.0 -const tol_rf = 1e-8 -# const Kelvin_natural_unit_conversion = # this is defined in recfast - # auxillary equations ϕ₂(T_b) = 0.448 * log(ε₀_H / T_b) α⁽²⁾(T_b) = (64π / √(27π)) * (α^2 / m_e^2) * √(ε₀_H / T_b) * ϕ₂(T_b) @@ -45,7 +84,7 @@ const tol_rf = 1e-8 β⁽²⁾(T_b) = β(T_b) * exp(3ε₀_H / 4T_b) n₁ₛ(a, Xₑ, par) = (1 - Xₑ) * n_H(a, par) #Problem is here \/ since Lyα rate is given by redshifting out of line need H -Λ_α(a, Xₑ, par) = H_a(a, par) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) +Λ_α(a, Xₑ, par) = oldH_a(a, par) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) new_Λ_α(a, Xₑ, par, ℋ_function) = ℋ_function(log(a)) * (3ε₀_H)^3 / ((8π)^2 * n₁ₛ(a, Xₑ, par)) Cᵣ(a, Xₑ, T_b, par) = (Λ_2s_to_1s + Λ_α(a, Xₑ, par)) / ( Λ_2s_to_1s + Λ_α(a, Xₑ, par) + β⁽²⁾(T_b)) @@ -53,10 +92,10 @@ new_Cᵣ(a, Xₑ, T_b, par,ℋ_function) = (Λ_2s_to_1s + new_Λ_α(a, Xₑ, par Λ_2s_to_1s + new_Λ_α(a, Xₑ, par,ℋ_function) + β⁽²⁾(T_b)) # RHS of Callin06 eq. 13 -function peebles_Xₑ′(Xₑ, par::CosmoParams{T}, x) where T +function peebles_Xₑ′(Xₑ, par, x) a = exp(x) - T_b_a = BigFloat(saha_T_b(a)) # handle overflows by switching to bigfloat - return T(Cᵣ(a, Xₑ, T_b_a, par) / H_a(a, par) * ( + T_b_a = BigFloat(saha_T_b(a, par)) # handle overflows by switching to bigfloat + return float(Cᵣ(a, Xₑ, T_b_a, par) / oldH_a(a, par) * ( β(T_b_a) * (1 - Xₑ) - n_H(a, par) * α⁽²⁾(T_b_a) * Xₑ^2)) end @@ -140,7 +179,7 @@ function τ′(x, Xₑ_function, par, ℋ_function) end function oldτ′(x, Xₑ_function, par) a = x2a(x) - return -Xₑ_function(x) * n_H(a, par) * a * σ_T / (a*H_a(a,par)) + return -Xₑ_function(x) * n_H(a, par) * a * σ_T / (a*oldH_a(a,par)) end function g̃_function(τ_x_function, τ′_x_function) @@ -149,7 +188,7 @@ end """Convenience function to create an ionisation history from some tables""" -function customion(par, bg, Xₑ_function, Tmat_function, csb²_function) +function customion(par::AbstractCosmoParams{T}, bg, Xₑ_function, Tmat_function, csb²_function) where T x_grid = bg.x_grid τ, τ′ = Bolt.τ_functions(x_grid, Xₑ_function, par, bg.ℋ) g̃ = Bolt.g̃_function(τ, τ′) @@ -157,11 +196,11 @@ function customion(par, bg, Xₑ_function, Tmat_function, csb²_function) Xₑ_ = spline(Xₑ_function.(x_grid), x_grid) τ_ = spline(τ.(x_grid), x_grid) g̃_ = spline(g̃.(x_grid), x_grid) - Tmat_ = spline(Tmat_function.(x_grid), x_grid) + Tmat_ = spline(T.(Tmat_function.(x_grid)), x_grid) csb²_ = spline(csb²_function.(x_grid), x_grid) return IonizationHistory( - (τ(0.)), + T(τ(0.)), Xₑ_, τ_, spline_∂ₓ(τ_, x_grid), @@ -173,106 +212,3 @@ function customion(par, bg, Xₑ_function, Tmat_function, csb²_function) csb²_, ) end - -#----------------------- - -function reionization_Xe(𝕡::CosmoParams, Xe_func, z) - X_fin = 1 + 𝕡.Y_p / ( not4_rf*(1-𝕡.Y_p) ) #ionization frac today - zre,α,ΔH,zHe,ΔHe,fHe = 7.6711,1.5,0.5,3.5,0.5,X_fin-1 #reion params, TO REPLACE - x_orig = Xe_func(z2x(z)) - x_reio_H = (X_fin - x_orig) / 2 * ( - 1 + tanh(( (1+zre)^α - (1+z)^α ) / ( α*(1+zre)^(α-1) ) / ΔH)) + x_orig - x_reio_He = fHe / 2 * ( 1 + tanh( (zHe - z) / ΔHe) ) - x_reio = x_reio_H + x_reio_He - return x_reio -end - -function reionization_Tmat_ode(Tm,z) - x_reio = reionization_Xe(𝕡, Xe_func,z) - a = 1 / (1+z) - x_a = a2x(a) - Hz = bg.ℋ(x_a) / a / H0_natural_unit_conversion - Trad =Tnow_rf * (1+z) - CT_rf = (8/3)*(sigma/(m_e_rf*C_rf))*a - fHe = 𝕡.Y_p/(not4_rf*(1 - 𝕡.Y_p)) - dTm = CT_rf * Trad^4 * x_reio/(1 + x_reio + fHe) * - (Tm - Trad) / (Hz * (1 + z)) + 2 * Tm / (1 + z) - return dTm -end - -function tanh_reio_solve(Tmat0; zre_ini=50.0,zfinal=0.0) - reio_prob = ODEProblem(reionization_Tmat_ode, - Tmat0, - (zre_ini, zfinal)) - sol_reio_Tmat = solve(reio_prob, Tsit5(), reltol=tol_rf) - trh = TanhReionizationHistory(zre_ini, ion_hist, sol_reio_Tmat); - return trh -end - - -# struct Peebles_hist{T, AB<:AbstractBackground{T},CT<:AbstractCosmoParams{T}} <: IonizationIntegrator -# par::CT -# bg::AB -# Xe -# end - -function ihPeebles(par::AbstractCosmoParams{T}, bg::AbstractBackground{T};zfinal=0.0) where T - - x_grid = bg.x_grid - Xₑ_function = saha_peebles_recombination(par) - τ, τ′ = τ_functions(x_grid, Xₑ_function, par, bg.ℋ) - g̃ = Bolt.g̃_function(τ, τ′) - spline, spline_∂ₓ, spline_∂ₓ² = Bolt.spline, Bolt.spline_∂ₓ, Bolt.spline_∂ₓ² - Xₑ_ = spline(Xₑ_function.(x_grid), x_grid) - τ_ = spline(τ.(x_grid), x_grid) - g̃_ = spline(g̃.(x_grid), x_grid) - - Tnow_rf = (15/ π^2 *bg.ρ_crit * par.Ω_r)^(1/4) * Kelvin_natural_unit_conversion #last thing is natural to K - Trad_function = x -> Tnow_rf * (1 + x2z(x)) - - # trhist = tanh_reio_solve(rhist) - Tmat0=Trad_function(xinitial_RECFAST) #FIXME CHECK - - function reionization_Tmat_ode(Tm,p,z) - x_reio = reionization_Xe(par, Xₑ_,z) - Trad = Tnow_rf * (1 + z) - Hz = bg.ℋ(z2x(z)) * (1 + z) / H0_natural_unit_conversion - a=z2a(z) - CT_rf = (8/3)*(sigma/(m_e_rf*C_rf))*a - fHe = par.Y_p/(not4_rf*(1 - par.Y_p)) - return CT_rf * Trad^4 * x_reio/(1 + x_reio + fHe) * - (Tm - Trad) / (Hz * (1 + z)) + 2 * Tm / (1 + z) - end - zre_ini=50.0 - reio_prob = ODEProblem(reionization_Tmat_ode, - Tmat0, - (zre_ini, zfinal)) - sol_reio_Tmat = solve(reio_prob, Tsit5(), reltol=tol_rf) - # trh = TanhReionizationHistory(zre_ini, ion_hist, sol_reio_Tmat) - - Tmat_function = x -> (x < z2x(zre_ini)) ? - Trad_function(x) : sol_reio_Tmat(x2z(x)) - - Tmat_ = spline(Tmat_function.(x_grid), x_grid) - Yp = par.Y_p - mu_T_rf = not4_rf/(not4_rf-(not4_rf-1)*Yp) - csb²_pre = @.( C_rf^-2 * k_B_rf/m_H_rf * ( 1/mu_T_rf + (1-Yp)*Xₑ_(x_grid) ) ) #not the most readable... - #FIXME probably this is a bad way to do this... - csb²_ = spline(csb²_pre .* (Tmat_.(x_grid) .- 1/3 *spline_∂ₓ(Tmat_, x_grid).(x_grid)),x_grid) - # csb²_ = spline(csb²_function.(x_grid), x_grid) - - # println("typeof(Xₑ_) $(typeof(Xₑ_))") - # println("typeof(τ_) $(typeof(τ_))") - return IonizationHistory( - T(τ(0.)), - Xₑ_, - τ_, - spline_∂ₓ(τ_, x_grid), - spline_∂ₓ²(τ_, x_grid), - g̃_, - spline_∂ₓ(g̃_, x_grid), - spline_∂ₓ²(g̃_, x_grid), - Tmat_, - csb²_, - ) -end diff --git a/src/perturbations.jl b/src/perturbations.jl index 4dd48a5..2c0ef81 100644 --- a/src/perturbations.jl +++ b/src/perturbations.jl @@ -20,21 +20,41 @@ Hierarchy(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::Abst struct Hierarchy_nn{T<:Real, PI<:PerturbationIntegrator, CP<:AbstractCosmoParams{T}, BG<:AbstractBackground, IH<:AbstractIonizationHistory, Tk<:Real, - AT<:Array{T,1}} + # S<:Real,AT<:AbstractArray{S,1} + } integrator::PI par::CP bg::BG ih::IH k::Tk -p1::AT #the -p2::AT +# p::AT #the +p::AbstractArray #the ℓᵧ::Int # Boltzmann hierarchy cutoff, i.e. Seljak & Zaldarriaga end Hierarchy_nn(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::AbstractBackground, -ih::AbstractIonizationHistory, k::Real, p1::AbstractArray, p2::AbstractArray,ℓᵧ=8) = Hierarchy(integrator, par, bg, ih, k, p1,p2,ℓᵧ) +ih::AbstractIonizationHistory, k::Real, p::AbstractArray,ℓᵧ=8 +) = Hierarchy_nn(integrator, par, bg, ih, k, p,ℓᵧ) +struct Hierarchy_spl{T<:Real, PI<:PerturbationIntegrator, CP<:AbstractCosmoParams{T}, + BG<:AbstractBackground, IH<:AbstractIonizationHistory, Tk<:Real, + S<:Real,IT<:AbstractInterpolation{S,1}} +integrator::PI +par::CP +bg::BG +ih::IH +k::Tk +spl1::IT +# spl2::IT +ℓᵧ::Int # Boltzmann hierarchy cutoff, i.e. Seljak & Zaldarriaga +end + +Hierarchy_spl(integrator::PerturbationIntegrator, par::AbstractCosmoParams, bg::AbstractBackground, +ih::AbstractIonizationHistory, k::Real, +spl1::AbstractInterpolation,spl2::AbstractInterpolation,ℓᵧ=8 +) = Hierarchy_spl(integrator, par, bg, ih, k, spl1,spl2,ℓᵧ) + function boltsolve(hierarchy::Hierarchy{T}, ode_alg=KenCarp4(); reltol=1e-6, abstol=1e-6) where T xᵢ = first(hierarchy.bg.x_grid) @@ -51,12 +71,20 @@ function boltsolve_nn(hierarchy::Hierarchy_nn{T}, ode_alg=KenCarp4(); reltol=1e- u₀ = initial_conditions_nn(xᵢ, hierarchy) prob = ODEProblem{true}(hierarchy_nn!, u₀, (xᵢ , zero(T)), hierarchy) sol = solve(prob, ode_alg, reltol=reltol, abstol=abstol, - saveat=hierarchy.bg.x_grid, dense=false, + saveat=hierarchy.bg.x_grid, ) return sol end - +function boltsolve_spl(hierarchy::Hierarchy_spl{T}, saveat::Array{T,1},ode_alg=KenCarp4(); reltol=1e-6, abstol=1e-6) where T + xᵢ = first(hierarchy.bg.x_grid) + u₀ = initial_conditions_spl(xᵢ, hierarchy) + prob = ODEProblem{true}(hierarchy_spl!, u₀, (xᵢ , zero(T)), hierarchy) + sol = solve(prob, ode_alg, reltol=reltol, abstol=abstol, + saveat=saveat, + ) + return sol +end # basic Newtonian gauge: establish the order of perturbative variables in the ODE solve function unpack(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T @@ -67,12 +95,20 @@ function unpack(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T return Θ, Θᵖ, Φ, δ, v, δ_b, v_b end -function unpack_nn(u, hierarchy::Hierarchy{T, BasicNewtonian}) where T +function unpack_nn(u, hierarchy::Hierarchy_nn{T, BasicNewtonian}) where T ℓᵧ = hierarchy.ℓᵧ Θ = OffsetVector(view(u, 1:(ℓᵧ+1)), 0:ℓᵧ) # indexed 0 through ℓᵧ Θᵖ = OffsetVector(view(u, (ℓᵧ+2):(2ℓᵧ+2)), 0:ℓᵧ) # indexed 0 through ℓᵧ - Φ, δ, σ, δ_b, v_b = view(u, (2(ℓᵧ+1)+1):(2(ℓᵧ+1)+5)) #getting a little messy... - return Θ, Θᵖ, Φ, δ, σ, δ_b, v_b + Φ, δ, v, δ_b, v_b = view(u, (2(ℓᵧ+1)+1):(2(ℓᵧ+1)+5)) #getting a little messy... + return Θ, Θᵖ, Φ, δ, v, δ_b, v_b +end + +function unpack_spl(u, hierarchy::Hierarchy_spl{T, BasicNewtonian}) where T + ℓᵧ = hierarchy.ℓᵧ + Θ = OffsetVector(view(u, 1:(ℓᵧ+1)), 0:ℓᵧ) # indexed 0 through ℓᵧ + Θᵖ = OffsetVector(view(u, (ℓᵧ+2):(2ℓᵧ+2)), 0:ℓᵧ) # indexed 0 through ℓᵧ + Φ, δ, v, δ_b, v_b = view(u, (2(ℓᵧ+1)+1):(2(ℓᵧ+1)+5)) #getting a little messy... + return Θ, Θᵖ, Φ, δ, v, δ_b, v_b end # BasicNewtonian comes from Callin+06 and the Dodelson textbook (dispatches on hierarchy.integrator) @@ -133,6 +169,17 @@ function hierarchy!(du, u, hierarchy::Hierarchy{T, BasicNewtonian}, x) where T return nothing end +function get_nn(m,d_in,d_out;rng=Xoshiro(123)) + NN = SimpleChain(static(d_in), + TurboDense{true}(tanh, m), + TurboDense{true}(tanh, m), + TurboDense{false}(identity, d_out) #have not tested non-scalar output + ); + p = SimpleChains.init_params(NN;rng); + G = SimpleChains.alloc_threaded_grad(NN); + return NN +end + function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T # compute cosmological quantities at time x, and do some unpacking k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih @@ -141,6 +188,7 @@ function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) whe a = x2a(x) R = 4Ω_r / (3Ω_b * a) csb² = ih.csb²(x) + pertlen=2(ℓᵧ+1)+5 α_c = par.α_c @@ -162,12 +210,82 @@ function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) whe ) # matter - θ₀,θ₂ = hierarchy.p1,hierarchy.p2 - NN₀,NN₂ = get_nn(θ₀,x,k),get_nn(θ₂,x,k) #get nn objects - δ′ = NN₀ #k / ℋₓ * v - 3Φ′ + # THIS DOESN'T WORK!!! + NN = get_nn(4,pertlen+2,32) + nnin = hcat([u...,k,x]) + u′ = NN(nnin,hierarchy.p) + δ′ = u′[1] #k / ℋₓ * v - 3Φ′ # v′ = -v - k / ℋₓ * Ψ - v′ = NN₂ + v′ = u′[2] + # σ′ = NN₂ + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + + + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + + # polarized photons + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + + # photon boundary conditions: diffusion damping + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + #END RSA + + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + # du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, σ′, δ_b′, v_b′ # put non-photon perturbations back in + return nothing +end + + +function hierarchy_spl!(du, u, hierarchy::Hierarchy_spl{T, BasicNewtonian}, x) where T + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + pertlen=2(ℓᵧ+1)+5 + + α_c = par.α_c + + # Θ, Θᵖ, Φ, δ_c, σ_c,δ_b, v_b = unpack_nn(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack_spl(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + + Θ′, Θᵖ′, _, _, _, _, _ = unpack_spl(du, hierarchy) # will be sweetened by .. syntax in 1.6 + + + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + # matter + δ′ = hierarchy.spl1(x) + v′ = -v - k / ℋₓ * Ψ + # v′ = hierarchy.spl2(x) # σ′ = NN₂ + #FIXME we don't actually need to learn v in the ODE, just a scalar at z=0... + #Well this depends on the assumption we make I think... δ_b′ = k / ℋₓ * v_b - 3Φ′ v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) @@ -201,6 +319,7 @@ end # BasicNewtonian Integrator (dispatches on hierarchy.integrator) + function initial_conditions(xᵢ, hierarchy::Hierarchy{T, BasicNewtonian}) where T k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih u = zeros(T, 2(ℓᵧ+1)+5) @@ -245,7 +364,7 @@ function initial_conditions_nn(xᵢ, hierarchy::Hierarchy_nn{T, BasicNewtonian}) u = zeros(T, 2(ℓᵧ+1)+5) ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(xᵢ), bg.ℋ′(xᵢ), bg.η(xᵢ), ih.τ′(xᵢ), ih.τ′′(xᵢ) # Θ, Θᵖ, Φ, δ, σ, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) - Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack_nn(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) H₀²,aᵢ² = bg.H₀^2,exp(xᵢ)^2 aᵢ = sqrt(aᵢ²) #These get a 3/3 since massive neutrinos behave as massless at time of ICs @@ -286,6 +405,46 @@ function initial_conditions_nn(xᵢ, hierarchy::Hierarchy_nn{T, BasicNewtonian}) return u end +function initial_conditions_spl(xᵢ, hierarchy::Hierarchy_spl{T, BasicNewtonian}) where T + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + u = zeros(T, 2(ℓᵧ+1)+5) + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(xᵢ), bg.ℋ′(xᵢ), bg.η(xᵢ), ih.τ′(xᵢ), ih.τ′′(xᵢ) + # Θ, Θᵖ, Φ, δ, σ, δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack_spl(u, hierarchy) # the Θ, Θᵖ are mutable views (see unpack) + H₀²,aᵢ² = bg.H₀^2,exp(xᵢ)^2 + aᵢ = sqrt(aᵢ²) + #These get a 3/3 since massive neutrinos behave as massless at time of ICs + f_ν = 1/(1 + 1/(7*(3/3)*3.046/8 *(4/11)^(4/3))) # we need to actually keep this + α_c = par.α_c + + # metric and matter perturbations + ℛ = 1.0 # set curvature perturbation to 1 + Φ = (4f_ν + 10) / (4f_ν + 15) * ℛ # for a mode outside the horizon in radiation era + #choosing Φ=1 forces the following value for C, the rest of the ICs follow + C = -( (15 + 4f_ν)/(20 + 8f_ν) ) * Φ + + #trailing (redundant) factors are for converting from MB to Dodelson convention for clarity + Θ[0] = -40C/(15 + 4f_ν) / 4 + Θ[1] = 10C/(15 + 4f_ν) * (k^2 * ηₓ) / (3*k) + Θ[2] = -8k / (15ℋₓ * τₓ′) * Θ[1] + Θᵖ[0] = (5/4) * Θ[2] + Θᵖ[1] = -k / (4ℋₓ * τₓ′) * Θ[2] + Θᵖ[2] = (1/4) * Θ[2] + for ℓ in 3:ℓᵧ + Θ[ℓ] = -ℓ/(2ℓ+1) * k/(ℋₓ * τₓ′) * Θ[ℓ-1] + Θᵖ[ℓ] = -ℓ/(2ℓ+1) * k/(ℋₓ * τₓ′) * Θᵖ[ℓ-1] + end + + δ = -α_c*Θ[0] # this is general enough to allow this to be any species + δ_b = δ + # v = -3k*Θ[1] + v = α_c*k*Θ[1] + v_b = -3k*Θ[1] + + u[2(ℓᵧ+1)+1:(2(ℓᵧ+1)+5)] .= Φ, δ, v, δ_b, v_b # write u with our variables + return u +end + #FIXME this is pretty old code that hasn't been tested in a while! # TODO: this could be extended to any Newtonian gauge integrator if we specify the # Bardeen potential Ψ and its derivative ψ′ for an integrator, or we saved them From 9211bdec0dc7e5f9b25c1653dd55221f3eb5f127 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Thu, 15 Jun 2023 19:13:13 -0700 Subject: [PATCH 04/10] loss plot and more opt, hacky recon plot though --- scripts/ude_fwddiff_deltac.jl | 62 +++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/scripts/ude_fwddiff_deltac.jl b/scripts/ude_fwddiff_deltac.jl index eb2541b..d3d9d34 100644 --- a/scripts/ude_fwddiff_deltac.jl +++ b/scripts/ude_fwddiff_deltac.jl @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics,Lux rng = Xoshiro(123) using Bolt using Plots -using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays +using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays,OptimizationOptimJL using AbstractDifferentiation import AbstractDifferentiation as AD, ForwardDiff @@ -136,36 +136,61 @@ AD.jacobian(ab,predict,ComponentArray(p)) losses = []; callback = function (p, l) push!(losses, l) - # if length(losses) % 50 == 0 + if length(losses) % 5 == 0 println("Current loss after $(length(losses)) iterations: $(losses[end])") - # end + end return false end -niter=2#80 -res1 = Optimization.solve(optprob, ADAM(1.0), callback = callback, maxiters = niter) +niter_1,niter_2,niter_3 =50,50,20 +η_1,η_2,η_3 = 1.0,0.1,0.01 +res1 = Optimization.solve(optprob, ADAM(1.0), callback = callback, maxiters = niter_1) # this is pretty slow, on the order of half an hour, but it is at least running! # Now idk what is wrong with reverse mode...but getting errors about typing... - +#Current loss after 50 iterations: 3.8709838260084415 # Get the result test_predict_o1 = predict(res1.u) +optprob2 = remake(optprob,u0 = res1.u); +res2 = Optimization.solve(optprob2, ADAM(η_2), callback = callback, maxiters = niter_2) +test_predict_o2 = predict(res2.u) +#Current loss after 100 iterations: 1.1319215191941459 + +optprob3 = remake(optprob2,u0 = res2.u); +res3 = Optimization.solve(optprob3, ADAM(η_3), callback = callback, maxiters = niter_3) +test_predict_o3 = predict(res3.u) +#Current loss after 120 iterations: 1.0862281116475199 + +# FIXME MORE INVOLVED ADAM SCHEDULE + +# for the heck of it try BFGS, not sure what parameters it usually takes +optprob4 = remake(optprob3,u0 = res3.u); +res4 = Optimization.solve(optprob3, BFGS(), + callback = callback, maxiters = 10) +test_predict_o4 = predict(res4.u) +# wow this is slow...I guess due to Hessian approximation? +# We have the gradient so idk why this takes so much longer than ADAM? +# Somehow loss actually goes up also? Maybe we overshoot, can try again with specified initial stepsize... # Plots of the learned perturbations Plots.scatter(bgtest.x_grid,Ytrain_ode[:,1],label="data") Plots.plot!(bgtest.x_grid,δ_true,label="truth",yscale=:log10,lw=2.5) Plots.plot!(bgtest.x_grid,test_predict_o1[:,1],label="opt-v1",lw=2.5,ls=:dash) +Plots.plot!(bgtest.x_grid,test_predict_o4[:,1],label="opt-v1-full",lw=2.5,ls=:dot) Plots.title!(raw"$\delta_{c}$") Plots.xlabel!(raw"$\log(a)$") Plots.ylabel!(raw"$\delta_{c}(a)$") -savefig("../plots/deltac_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") +savefig("../plots/deltac_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") -Plots.scatter(bgtest.x_grid,Ytrain_ode[:,2],label="data",yscale=:log10,legend=:bottomright) +log10.(test_predict_o4[:,2]) + +Plots.scatter(bgtest.x_grid,Ytrain_ode[:,2],label="data")#,legend=:bottomright) Plots.plot!(bgtest.x_grid,v_true,label="truth") Plots.plot!(bgtest.x_grid,test_predict_o1[:,2],label="opt-v1") +Plots.plot!(bgtest.x_grid,test_predict_o4[:,2],label="opt-v1-full",lw=2.5,ls=:dot) Plots.title!(raw"$v_{c}$") Plots.xlabel!(raw"$\log(a)$") Plots.ylabel!(raw"$v_{c}(a)$") -savefig("../plots/vc_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") +savefig("../plots/vc_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") function get_Φ′_Ψ(u,hierarchy::Hierarchy{T},x) where T @@ -196,7 +221,8 @@ nn_δ′,nn_v′ = zeros(length(hierarchytest.bg.x_grid)),zeros(length(hierarchy for j in 1:length(hierarchytest.bg.x_grid) Φ′_true[j],Ψ_true[j] = get_Φ′_Ψ(ode_data[:,j],hierarchytest,hierarchytest.bg.x_grid[j]) nnin = hcat([ode_data[:,j]...,hierarchytest.bg.x_grid[j]]) - nn_u′ = U(nnin,res1.u,st)[1] + # nn_u′ = U(nnin,res1.u,st)[1] + nn_u′ = U(nnin,res4.u,st)[1] nn_δ′[j],nn_v′[j] = nn_u′[1], nn_u′[2] end @@ -206,17 +232,27 @@ true_v′ = -v_true .- hierarchytestnn.k ./ hierarchytestnn.bg.ℋ .* Ψ_true Plots.plot(hierarchytestnn.bg.x_grid, true_δ′,label="truth",lw=2.5) Plots.plot!(hierarchytestnn.bg.x_grid,nn_δ′,label="nn-v1",lw=2.5) +Plots.plot!(hierarchytestnn.bg.x_grid,nn_δ′,label="nn-v1-full",lw=2.5) Plots.xlabel!(raw"$\log(a)$") Plots.ylabel!(raw"$\delta'(a)$") -Plots.title!(raw"recon $v_{c}$") -savefig("../plots/deltacprime_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") +Plots.title!(raw"recon $\delta_{c}$") +savefig("../plots/deltacprime_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") Plots.plot(hierarchytestnn.bg.x_grid,true_v′,label="truth",lw=2.5) Plots.plot!(hierarchytestnn.bg.x_grid,nn_v′,label="nn-v1",lw=2.5) +Plots.plot!(hierarchytestnn.bg.x_grid,nn_v′,label="nn-v1-full",lw=2.5) Plots.xlabel!(raw"$\log(a)$") Plots.ylabel!(raw"$v'(a)$") Plots.title!(raw"recon $v_{c}$") -savefig("../plots/vcprime_learning_v1_multnoise$(σfakeode)_Adam$(niter)_$(η).png") +savefig("../plots/vcprime_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") + + +# loss +Plots.plot(losses) +Plots.xlabel!(raw"iters") +Plots.ylabel!(raw"loss") +savefig("../plots/loss_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") + # -------------------------------- # Old code testing AD backends From 2220e79d7f7afd98027457272a2e80a5be9c246d Mon Sep 17 00:00:00 2001 From: James Sullivan Date: Tue, 25 Jul 2023 07:01:09 -0700 Subject: [PATCH 05/10] commnd line file --- Project.toml | 4 +- scripts/test_ude_fwddiff_deltac.jl | 220 +++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 scripts/test_ude_fwddiff_deltac.jl diff --git a/Project.toml b/Project.toml index d583527..b48845f 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -27,6 +27,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypergeometricFunctions = "34004b35-14d8-5ef3-9330-4cdb6864b03a" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -60,6 +61,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f" UnitfulCosmo = "961331e1-62bb-46d9-b9e3-f058129f1391" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/scripts/test_ude_fwddiff_deltac.jl b/scripts/test_ude_fwddiff_deltac.jl new file mode 100644 index 0000000..7a845b5 --- /dev/null +++ b/scripts/test_ude_fwddiff_deltac.jl @@ -0,0 +1,220 @@ +# Try replacing Lux with SimpleChains +seed = parse(Int64, ARGS[1]) +println("seed = $(seed)") + +import Pkg +Pkg.activate("/global/cfs/cdirs/m4051/ude/Bolt.jl") +using OrdinaryDiffEq +using Random +using LinearAlgebra, Statistics,Lux +using Bolt +using Plots +using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays,OptimizationOptimJL +using AbstractDifferentiation +import AbstractDifferentiation as AD, ForwardDiff +using DelimitedFiles +rng = Xoshiro(seed) + + +println("Done loading packages") +# setup ks won't use all of them here... +L=2f3 +lkmi,lkmax,nk = log10(2.0f0*π/L),log10(0.2f0),8 +kk = 10.0f0.^(collect(lkmi:(lkmax-lkmi)/(nk-1):lkmax)) +ℓᵧ=15 +pertlen=2(ℓᵧ+1)+5 + +# define network +m = 12 +U = Lux.Chain(Lux.Dense(pertlen+1, m, tanh), #input is u,t + Lux.Dense(m, m,tanh), + Lux.Dense(m, 2)) +p, st = Lux.setup(rng, U) + + + +# copy the hierarchy function os it works for nn - for this to work you need the Hierarchy_nn struct and unpack_nn in perturbations.jl +function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + + # the new free cdm index (not used here) + α_c = par.α_c + # get the nn params + p_nn = hierarchy.p + + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack_nn(u, hierarchy) + Θ′, Θᵖ′, _, _, _, _, _ = unpack_nn(du, hierarchy) + + # Here I am throwing away the neutrinos entriely, which is probably bad + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c #ignore this + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + # matter + nnin = hcat([u...,x]) + û = U(nnin,p_nn,st)[1] + δ′ = û[1] + v′ = û[2] + # here we implicitly assume σ_c = 0 + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * Bolt.δ_kron(ℓ, 2) / 10) + end + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ + return nothing +end + +# use only the longest k mode +function hierarchy_nnu!(du, u, p, x) + hierarchy = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],p,15); + hierarchy_nn!(du, u, hierarchy, x) +end + + + +# some setup +tspan = (-20.0f0, 0.0f0) +𝕡test = CosmoParams(Ω_c=0.3,α_c=-3.0); +bgtest = Background(𝕡test; x_grid=-20.0f0:1f-1:0.0f0); +ihtest = Bolt.get_saha_ih(𝕡test, bgtest); +hierarchytest = Hierarchy(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],15); +hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], ComponentArray(p),15); +u0 = Bolt.initial_conditions_nn(tspan[1],hierarchytestnn) + +# problem for truth and one we will remake +prob_trueode = ODEProblem(Bolt.hierarchy!, u0, tspan,hierarchytest) +prob_nn = ODEProblem(hierarchy_nnu!, u0, (bgtest.x_grid[1],bgtest.x_grid[end]), ComponentArray(p)) + +# Generate some noisy data (at reduced solver tolerance) +ode_data = Array(solve(prob_trueode, KenCarp4(), saveat = bgtest.x_grid, + abstol = 1e-3, reltol = 1e-3)) +δ_true,v_true = ode_data[end-3,:],ode_data[end-2,:] +σfakeode = 0.1 +noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode +noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode +Ytrain_ode = hcat([δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v]...) +# noise_both = Float32.(hcat([noise_fakeode,noise_fakeode_v]...)) +noise_both = Float32.(hcat([δ_true.*σfakeode,v_true.*σfakeode]...)) + +# NB I dropped the "Float64" type argument to the ComponentArray, maybe we should put it back? +function predict(θ, T = bgtest.x_grid) + _prob = remake(prob_nn, u0 = u0, tspan = (T[1], T[end]), p = θ) + res = Array(solve(_prob, KenCarp4(), saveat = T, + abstol = 1e-3, reltol = 1e-3)) + return hcat(res[end-3,:],res[end-2,:]) +end + +#log loss +# function loss(θ) +# X̂ = predict(θ) +# log(mean(abs2, (Ytrain_ode - X̂)./noise_both ) ) +# end +#raw loss +function loss(θ) + X̂ = predict(θ) + log(mean(abs2, (Ytrain_ode - X̂)) ) +end + +# adtype = Optimization.AutoZygote() +adtype = Optimization.AutoForwardDiff() +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) +optprob = Optimization.OptimizationProblem(optf, ComponentVector(p)) + + +# Training +losses,pp = [],[]; +callback = function (p, l) + push!(losses, l) + push!(pp, p) + if length(losses) % 5 == 0 + println("Current loss after $(length(losses)) iterations: $(losses[end])") + end + return false +end +niter_0,η_0 = 20,1.0 +niter_1,niter_2,niter_3 =50,50,20 +η_1,η_2,η_3 = 1.0,0.1,0.01 +res1 = Optimization.solve(optprob, ADAM(η_1), callback = callback, maxiters = niter_1) +test_predict_o1 = predict(res1.u) + +optprob2 = remake(optprob,u0 = res1.u); +res2 = Optimization.solve(optprob2, ADAM(η_2), callback = callback, maxiters = niter_2) +test_predict_o2 = predict(res2.u) + +optprob3 = remake(optprob2,u0 = res2.u); +res3 = Optimization.solve(optprob3, ADAM(η_3), callback = callback, maxiters = niter_3) +test_predict_o3 = predict(res3.u) + +println("done optimizing") + +# save the results +writedlm("../../data/t_raw_enn_v1_seed$(seed).dat",predict(res3.u)) +writedlm("../../data/t_raw_loss_v1_seed$(seed).dat",losses) +writedlm("../../data/t_raw_params_v1_seed$(seed).dat",pp) + +println("done saving") + +function get_Φ′_Ψ(u,hierarchy::Hierarchy{T},x) where T + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + α_c = par.α_c + Θ, Θᵖ, Φ, δ, v, δ_b, v_b = unpack(u, hierarchy) + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + return Φ′,Ψ +end + +# The reconstructed function +Φ′_true,Ψ_true = zeros(length(hierarchytest.bg.x_grid)),zeros(length(hierarchytest.bg.x_grid)) +nn_δ′,nn_v′ = zeros(length(hierarchytest.bg.x_grid)),zeros(length(hierarchytest.bg.x_grid)) +for j in 1:length(hierarchytest.bg.x_grid) + Φ′_true[j],Ψ_true[j] = get_Φ′_Ψ(ode_data[:,j],hierarchytest,hierarchytest.bg.x_grid[j]) + nnin = hcat([ode_data[:,j]...,hierarchytest.bg.x_grid[j]]) + nn_u′ = U(nnin,res3.u,st)[1] + nn_δ′[j],nn_v′[j] = nn_u′[1], nn_u′[2] +end + + +true_δ′ = hierarchytestnn.k ./ hierarchytestnn.bg.ℋ .* v_true .- 3Φ′_true +true_v′ = -v_true .- hierarchytestnn.k ./ hierarchytestnn.bg.ℋ .* Ψ_true + +writedlm("../../data/t_raw_true_v1_seed$(seed).dat",[true_δ′,true_v′]) +writedlm("../../data/t_raw_recon_v1_seed$(seed).dat",[nn_δ′,nn_v′]) \ No newline at end of file From 286cc64fb35404908f8319eeca54a49c365c0771 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Tue, 25 Jul 2023 16:42:34 +0200 Subject: [PATCH 06/10] bg hmc on delta --- scripts/hmc_bg.jl | 138 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 scripts/hmc_bg.jl diff --git a/scripts/hmc_bg.jl b/scripts/hmc_bg.jl new file mode 100644 index 0000000..e62fbe5 --- /dev/null +++ b/scripts/hmc_bg.jl @@ -0,0 +1,138 @@ +using SimpleChains +using Plots,CairoMakie,PairPlots +using LinearAlgebra +using AdvancedHMC, ForwardDiff +using OrdinaryDiffEq +using LogDensityProblems +using BenchmarkTools +using DelimitedFiles +using Plots.PlotMeasures +using DataInterpolations +using Random,Distributions +rng = Xoshiro(123) + +# function to get a simplechain (fixed network architecture) +function train_initial_network(Xtrain,Ytrain,rng; λ=0.0f0, m=128,use_l_bias=true,use_L_bias=false,N_ADAM=1_000,N_epochs = 3,N_rounds=4) + #regularized network + d_in,d_out = size(Xtrain)[1],size(Ytrain)[1] + mlpd = SimpleChain( + static(d_in), + TurboDense{use_l_bias}(tanh, m), + TurboDense{use_l_bias}(tanh, m), + TurboDense{use_L_bias}(identity, d_out) #have not tested non-scalar output + ) + p = SimpleChains.init_params(mlpd;rng); + G = SimpleChains.alloc_threaded_grad(mlpd); + mlpdloss = SimpleChains.add_loss(mlpd, SquaredLoss(Ytrain)) + mlpdloss_reg = FrontLastPenalty(SimpleChains.add_loss(mlpd, SquaredLoss(Ytrain)), + L2Penalty(λ), L2Penalty(λ) ) + loss = λ > 0.0f0 ? mlpdloss_reg : mlpdloss + for k in 1:N_rounds + for _ in 1:N_epochs #FIXME, if not looking at this, don't need to split it up like so... + SimpleChains.train_unbatched!( + G, p, loss, Xtrain, SimpleChains.ADAM(), N_ADAM + ); + end + end + mlpd_noloss = SimpleChains.remove_loss(mlpd) + return mlpd_noloss,p +end + + +function run_hmc(ℓπ,initial_θ;n_samples=2_000,n_adapts=1_000,backend=ForwardDiff) + D=size(initial_θ) + metric = DiagEuclideanMetric(D)#ones(Float64,D).*0.01) + hamiltonian = Hamiltonian(metric, ℓπ, backend) + initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) + integrator = Leapfrog(initial_ϵ) + proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator) + adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)) + samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true); + h_samples = hcat(samples...)' + return h_samples = hcat(samples...)',stats +end + +using Bolt +# first try just sampling with ODE over the 2 variables of interest with HMC and true ODE +function dm_model(p_dm::Array{DT,1};kMpch=0.01,ℓᵧ=15,reltol=1e-5,abstol=1e-5) where DT #k in h/Mpc + Ω_c,α_c = exp(p_dm[1]),-exp(p_dm[2]) #log pos params + # println("Ω_c,α_c = $(Ω_c) $(α_c)") + #standard code + 𝕡 = CosmoParams{DT}(Ω_c=Ω_c,α_c=α_c) + bg = Background(𝕡; x_grid=-20.0:0.1:0.0) + # 𝕣 = Bolt.RECFAST(bg=bg, Yp=𝕡.Y_p, OmegaB=𝕡.Ω_b, OmegaG=𝕡.Ω_r) + # ih = IonizationHistory(𝕣, 𝕡, bg) + ih= Bolt.get_saha_ih(𝕡, bg); + + k = 𝕡.h*kMpch #get k in our units + hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ) + results = boltsolve(hierarchy; reltol=reltol, abstol=abstol) + res = hcat(results.(bg.x_grid)...) + δ_c,v_c = res[end-3,:],res[end-2,:] + return δ_c#,v_c +end + +dm_model([log(0.3),log(- -3.0)]) + +# δ_true,v_true = dm_model([log(0.3),log(- -3.0)]); +δ_true= dm_model([log(0.3),log(- -3.0)]); +δ_true + +#test jacobian +jdmtest = ForwardDiff.jacobian(dm_model,[log(0.3),log(- -3.0)]) +jdmtest + +#HMC for 2 ode params +# this is obviously very artificial, multiplicative noise +σfakeode = 0.1 +noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode +# noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode +# Ytrain_ode = [δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v] +Ytrain_ode = δ_true .+ noise_fakeode +noise_fakeode +Plots.plot(δ_true) +Plots.scatter!(Ytrain_ode) + +# density functions for HMC +struct LogTargetDensity_ode + dim::Int +end + +LogDensityProblems.logdensity(p::LogTargetDensity_ode, θ) = -sum(abs2, (Ytrain_ode-dm_model(θ))./noise_fakeode) / 2 # standard multivariate normal +LogDensityProblems.dimension(p::LogTargetDensity_ode) = p.dim +LogDensityProblems.capabilities(::Type{LogTargetDensity_ode}) = LogDensityProblems.LogDensityOrder{0}() +ℓπode = LogTargetDensity_ode(2) + +initial_ode = [log(0.3),log(- -3.0)]; +ode_samples, ode_stats = run_hmc(ℓπode,initial_ode;n_samples=200,n_adapts=100) + +ode_samples + +ode_stats +ode_labels = Dict( + :α => "parameter 1", + :β => "parameter 2" +) + +# Annoyingly PairPlots provides very nice functionality only if you use DataFrames (or Tables) +α,β = ode_samples[:,1], ode_samples[:,2] +using DataFrames +df = DataFrame(;α,β) + +pairplot(df ,PairPlots.Truth( + (;α =log(0.3),β=log(- -3.0)), + label="Truth" ) ) + +LogDensityProblems.logdensity(ℓπode,[log(0.3),log(- -5.0)]) + +ForwardDiff.jacobian() + + +@btime LogDensityProblems.logdensity(ℓπode,initial_ode) +# 14.466 ms (7036 allocations: 1.73 MiB) +@profview LogDensityProblems.logdensity(ℓπode,initial_ode) +# so this is just slow, maybe not surprising + +# open("./test/ode_deltac_hmc_trueinit_samplesshort.dat", "w") do io +# writedlm(io, ode_samples) +# end From 3cac1fa9ab1e43c54d4ea21754f28a94a32cb3ba Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Tue, 25 Jul 2023 16:58:36 +0200 Subject: [PATCH 07/10] sc version of opt --- scripts/ude_fwddiff_deltac_sc.jl | 241 +++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 scripts/ude_fwddiff_deltac_sc.jl diff --git a/scripts/ude_fwddiff_deltac_sc.jl b/scripts/ude_fwddiff_deltac_sc.jl new file mode 100644 index 0000000..b563933 --- /dev/null +++ b/scripts/ude_fwddiff_deltac_sc.jl @@ -0,0 +1,241 @@ +using OrdinaryDiffEq +using Random +using LinearAlgebra, Statistics,Lux +rng = Xoshiro(123) +using Bolt +using Plots +using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays,OptimizationOptimJL +using AbstractDifferentiation +import AbstractDifferentiation as AD, ForwardDiff +using SimpleChains +using AdvancedHMC,LogDensityProblems + +# setup ks won't use all of them here... +L=2f3 +lkmi,lkmax,nk = log10(2.0f0*π/L),log10(0.2f0),8 +kk = 10.0f0.^(collect(lkmi:(lkmax-lkmi)/(nk-1):lkmax)) +ℓᵧ=15 +pertlen=2(ℓᵧ+1)+5 + +# define network +width=8 +# U = Lux.Chain(Lux.Dense(pertlen+1, width, tanh), #input is u,t +# Lux.Dense(width, width,tanh), +# Lux.Dense(width, 2)) +# p, st = Lux.setup(rng, U) +function get_nn(m,d_in) + NN = SimpleChain(static(d_in), + TurboDense{true}(tanh, m), + TurboDense{true}(tanh, m), + TurboDense{false}(identity, 2) #have not tested non-scalar output + ); + p = SimpleChains.init_params(NN;rng); + G = SimpleChains.alloc_threaded_grad(NN); + return NN,p,G +end + +NN,p,G = get_nn(width,pertlen+1) + + +# copy the hierarchy function os it works for nn - for this to work you need the Hierarchy_nn struct and unpack_nn in perturbations.jl + +function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + + # the new free cdm index (not used here) + α_c = par.α_c + # get the nn params + p_nn = hierarchy.p + + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack_nn(u, hierarchy) + Θ′, Θᵖ′, _, _, _, _, _ = unpack_nn(du, hierarchy) + + # Here I am throwing away the neutrinos entriely, which is probably bad + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c #ignore this + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + + # matter + nnin = hcat([u...,x]) + # û = U(nnin,p_nn,st)[1] + # û = sc_noloss(nnin,p_nn)' .* std(true_u′,dims=1) .+ mean(true_u′,dims=1) + û = NN(nnin,p_nn)' + δ′ = û[1] + v′ = û[2] + # here we implicitly assume σ_c = 0 + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * Bolt.δ_kron(ℓ, 2) / 10) + end + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ + return nothing +end + +# use only the longest k mode +function hierarchy_nnu!(du, u, p, x) + hierarchy = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],p,15); + hierarchy_nn!(du, u, hierarchy, x) +end + + + +# some setup +tspan = (-20.0f0, 0.0f0) +𝕡test = CosmoParams(Ω_c=0.3,α_c=-3.0); +bgtest = Background(𝕡test; x_grid=-20.0f0:1f-1:0.0f0); +ihtest = Bolt.get_saha_ih(𝕡test, bgtest); +hierarchytest = Hierarchy(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],15); +hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],p,15); +u0 = Bolt.initial_conditions_nn(tspan[1],hierarchytestnn) + +NN(hcat([u0...,-20.f0]),p) +# NN([-20.f0],p) + +# problem for truth and one we will remake +prob_trueode = ODEProblem(Bolt.hierarchy!, u0, tspan,hierarchytest) +# prob_nn = ODEProblem(hierarchy_nnu!, u0, (bgtest.x_grid[1],bgtest.x_grid[end]), ComponentArray{Float64}(p)) + +# Generate some noisy data (at reduced solver tolerance) +ode_data = Array(solve(prob_trueode, KenCarp4(), saveat = bgtest.x_grid, + abstol = 1f-3, reltol = 1f-3)) +δ_true,v_true = ode_data[end-3,:],ode_data[end-2,:] +σfakeode = 0.1f0 +noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode +noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode +Ytrain_ode = hcat([δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v]...) +noise_both_old = Float32.(hcat([noise_fakeode,noise_fakeode_v]...)) +noise_both = Float32.(hcat([δ_true.*σfakeode,v_true.*σfakeode]...)) + +#float conversion +fl_xgrid = Float32.(bgtest.x_grid) +fu0 = Float32.(u0) +prob_nn = ODEProblem(hierarchy_nnu!, fu0, (fl_xgrid[1],fl_xgrid[end]), p) + +function predict(θ, T = fl_xgrid) + _prob = remake(prob_nn, u0 = fu0, tspan = (T[1], T[end]), p = θ) + res = Array(solve(_prob, KenCarp4(), saveat = T, + abstol = 1f-3, reltol = 1f-3)) + return hcat(res[end-3,:],res[end-2,:]) +end + +#log loss +function loss(θ) + X̂ = predict(θ) + log(sum(abs2, (Ytrain_ode - X̂)./ noise_both ) ) +end + + +# adtype = Optimization.AutoZygote() +adtype = Optimization.AutoForwardDiff() +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) +optprob = Optimization.OptimizationProblem(optf,p) + + + +# Training +function make_gif_plot(prediction,iter) + p = Plots.plot() + Plots.scatter!(p,bgtest.x_grid,Ytrain_ode[:,1],label="data") + Plots.plot!(p,bgtest.x_grid,prediction,lw=2,label="iter = $(iter)") + Plots.xlabel!(p,raw"$\log(a)$") + Plots.ylabel!(p,raw"$\delta_{c}(a)$") + savefig("../plots/learning_v1_multnoise$(σfakeode)_Adam$(iter)_s123_sc_chisq_to.png") + p +end + +losses = []; +pp =[]; +callback = function (p, l) + push!(losses, l) + push!(pp, p) + if length(losses) % 5 == 0 + println("Current loss after $(length(losses)) iterations: $(losses[end])") + make_gif_plot(predict(p)[:,1],length(losses)) + end + return false +end +niter_1,niter_2,niter_3 =50,50,20 +η_1,η_2,η_3 = 1.f0,0.1f0,0.01f0 +#loss doesn't go down very much at all with 1f-3 η1 + +res1 = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = niter_1) +# this is pretty slow, on the order of half an hour, but it is at least running! +#Current loss after 50 iterations: 3.8709838260084415 +test_predict_o1 = predict(res1.u) + +optprob2 = remake(optprob,u0 = res1.u); +res2 = Optimization.solve(optprob2, ADAM(η_2), callback = callback, maxiters = niter_2) +test_predict_o2 = predict(res2.u) +#Current loss after 100 iterations: 1.1319215191941459 + +optprob3 = remake(optprob,u0 = res2.u); +res3 = Optimization.solve(optprob3, ADAM(η_3), callback = callback, maxiters = 100) +test_predict_o3 = predict(res3.u) +#Current loss after 120 iterations: 1.0862281116475199 + +# FIXME MORE INVOLVED ADAM SCHEDULE +# for the heck of it try BFGS, not sure what parameters it usually takes +optprob4 = remake(optprob3,u0 = res4.u); +# res4 = Optimization.solve(optprob4, BFGS(), +res4 = Optimization.solve(optprob4, ADAM(η_3/1000), callback = callback, + maxiters = 100) + # callback = callback, maxiters = 10) +test_predict_o4 = predict(res4.u) +# wow this is slow...I guess due to Hessian approximation? +# We have the gradient so idk why this takes so much longer than ADAM? +# Somehow loss actually goes up also? Maybe we overshoot, can try again with specified initial stepsize... +test_predict_o1[:,1] +# Plots of the learned perturbations +Plots.scatter(bgtest.x_grid,Ytrain_ode[:,1],label="data") +Plots.plot!(bgtest.x_grid,δ_true,label="truth",lw=2.5)#,yscale=:log10) +Plots.plot!(bgtest.x_grid,predict(pc)[:,1],label="opt-v1",lw=2.5,ls=:dash) +Plots.title!(raw"$\delta_{c}$") +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$\delta_{c}(a)$") +savefig("../plots/deltac_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") +pc +p +log10.(test_predict_o4[:,2]) + + +# It *SEEMS LIKE* the model isn't flexible enough - i.e. when I shift to +# weighted loss from square loss the late exponential part gets worse +# while the early part gets worse. +# For SC training earlier, regularization helped a bit but not much... + +Plots.scatter(bgtest.x_grid,Ytrain_ode[:,2],label="data")#,legend=:bottomright) +Plots.plot!(bgtest.x_grid,v_true,label="truth") +Plots.plot!(bgtest.x_grid,test_predict_o1[:,2],label="opt-v1") +Plots.plot!(bgtest.x_grid,test_predict_o4[:,2],label="opt-v1-full",lw=2.5,ls=:dot) +Plots.plot!(bgtest.x_grid,predict(pc)[:,2],label="opt-v1",lw=2.5,ls=:dash) +Plots.title!(raw"$v_{c}$") +Plots.xlabel!(raw"$\log(a)$") +Plots.ylabel!(raw"$v_{c}(a)$") +savefig("../plots/vc_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png") + From 731e81bdd1c18ba31de0504966b85189de2d6a62 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Fri, 8 Sep 2023 12:12:47 -0700 Subject: [PATCH 08/10] basic adj nn code --- scripts/ret_sc_adjoint.jl | 92 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 scripts/ret_sc_adjoint.jl diff --git a/scripts/ret_sc_adjoint.jl b/scripts/ret_sc_adjoint.jl new file mode 100644 index 0000000..0a20b5e --- /dev/null +++ b/scripts/ret_sc_adjoint.jl @@ -0,0 +1,92 @@ +using SimpleChains +using ForwardDiff +using OrdinaryDiffEq +using DataInterpolations +using Random,Distributions +rng = Xoshiro(123) +using Optimization,SciMLSensitivity,OptimizationOptimisers#,ComponentArrays # for fwd diff +using SciMLSensitivity + +# Setup +#------------------------------------------------------------------------ + +# general NN function +function get_nn(m,d_in) + NN = SimpleChain(static(d_in), + TurboDense{true}(tanh, m), + # TurboDense{true}(tanh, m), + TurboDense{false}(identity, 1) + ); + p = SimpleChains.init_params(NN;rng); + G = SimpleChains.alloc_threaded_grad(NN); + return NN,p,G +end + +# Simple ODE problem +function f_true!(du,u,p,t) + du .= p.*u + return nothing +end +u0 = [1.f0]; +tspan = (0.f0,1.f0); +λp = -1.0f0; +prob_true = ODEProblem(f_true!,u0,tspan, λp) +tt = 1f-2:1f-2:1.f0 +rtol=atol=1e-12 + +# Generate noisy "data" solution +σrel = 0.01f0; #multiplicative noise +true_soln = solve(prob_true,Tsit5(), saveat = tt, + abstol=atol,reltol=rtol) + +#training data with noise +Ytrain = vcat(Array(true_soln)...).*(1.f0.+σrel.*randn(rng,eltype(u0),size(tt))) +Ytrain_interp = CubicSpline(Ytrain,tt); #interpolated training data + +# neural network function (here we don't let it depend on t bc it doesn't need it, but in general it should) +function f_nn!(du,u,p,t) + du .= NN(u,p) + return nothing +end + +# NN ODE setup (initialization) +NN,p₀,G = get_nn(16,1); # shallow narrow network for testing +prob_nn = ODEProblem{true}(f_nn!,u0,tspan,p₀); +nn_soln = solve(prob_nn,Tsit5(),abstol=atol,reltol=rtol) + +function predict_interp(θ) + _prob = remake(prob_nn, u0 = u0, tspan = tspan, p = θ) + res = solve(_prob, Tsit5(), #saveat = collect(tt), + abstol = atol, reltol = rtol,dense=true) + return res +end + +function loss_cts(θ;Yinterp=Ytrain_interp,ttfine = 1f-2:1f-3:1.f0) + X̂ = predict_interp(θ)(ttfine) + sum((Yinterp(ttfine) .- Array(X̂)[1,:]).^2)*(ttfine[2]-ttfine[1]) #Riemann sum approx of cts integral +end + +# ---> This is where I had a long training step, but took it out for simplicity (can put it back) + +# Get the sciML adjoints +dg_cts(out, u, p, t) = (out .= -2.f0*(Ytrain_interp(t) - u[1])) +g_cts(u,p,t) = (Ytrain_interp(t) - u[1])^2 +sciml_adj_soln_cts = adjoint_sensitivities(predict_interp(p₀),# (p is not learned, so will be wrong) + Tsit5(), + g=g_cts, dgdu_continuous = dg_cts, + abstol = atol,reltol = rtol, # note in general we probably don't want to use the same fwd/back tols + # iabstol=atol_sciml,ireltol=rtol_sciml + sensealg = InterpolatingAdjoint(autojacvec=false, autodiff=true),#( autojacvec = true, autodiff=true) + ); + +# This is the adjoint gradient (structure of the returne thing is (λ(t=t₀), grad) ): +sciml_adj_grad_cts = sciml_adj_soln_cts[2] + + +# Compare to FwdDiffGradient +fwd_diff_grad_cts = ForwardDiff.gradient(loss_cts,p₀) + +prod( abs.(sciml_adj_grad_cts[1,:] .- fwd_diff_grad_cts) .<4f-5 )#true + + +# INSERT HMC CALLING THIS GRADIENT ON THE ABOVE LOSS HERE \ No newline at end of file From 61daa6eb48c5c4c857d3f87db225f3b584024895 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Fri, 15 Sep 2023 16:24:02 -0700 Subject: [PATCH 09/10] script with type and reversediff --- scripts/adjoint_boltsolve.jl | 375 +++++++++++++++++++++++++++++++++++ src/background.jl | 12 +- 2 files changed, 381 insertions(+), 6 deletions(-) create mode 100644 scripts/adjoint_boltsolve.jl diff --git a/scripts/adjoint_boltsolve.jl b/scripts/adjoint_boltsolve.jl new file mode 100644 index 0000000..bb46267 --- /dev/null +++ b/scripts/adjoint_boltsolve.jl @@ -0,0 +1,375 @@ +# Basic attempt to use sciml +using OrdinaryDiffEq, SciMLSensitivity +using SimpleChains +using Random +rng = Xoshiro(123); +using Plots +# Bolt boilerplate +using Bolt +kMpch=0.01; +ℓᵧ=3; +reltol=1e-5; +abstol=1e-5; +p_dm=[log(0.3),log(- -3.0)] +Ω_c,α_c = exp(p_dm[1]),-exp(p_dm[2]) #log pos params +#standard code +𝕡 = CosmoParams{Float32}(Ω_c=Ω_c,α_c=α_c) +𝕡 + +bg = Background(𝕡; x_grid=-20.0f0:0.1f0:0.0f0) + +typeof(Bolt.η(-20.f0,𝕡,zeros(Float32,5),zeros(Float32,5))) + + +ih= Bolt.get_saha_ih(𝕡, bg); +k = 𝕡.h*kMpch #get k in our units + +typeof(8π) + +# hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], p1,15); + +# hierarchytestspl = Hierarchy_spl(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], +# Bolt.spline(ones(length(bgtest.x_grid)),bgtest.x_grid), +# Bolt.spline(ones(length(bgtest.x_grid)),bgtest.x_grid), +# 15); + +# u0test = Bolt.initial_conditions_nn(-20.0,hierarchytestnn); + +# hierarchy_nn!(zeros(pertlen+2), u0test, hierarchytestnn, -20.0 ) + +hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ); +results = boltsolve(hierarchy; reltol=reltol, abstol=abstol); +res = hcat(results.(bg.x_grid)...) +# δ_c,v_c = res[end-3,:],res[end-2,:] + + +# Deconstruct the boltsolve... +#------------------------------------------- +xᵢ = first(Float32.(hierarchy.bg.x_grid)) +u₀ = Float32.(initial_conditions(xᵢ, hierarchy)); +# prob = ODEProblem{true}(hierarchy_nn!, u₀, (xᵢ , zero(T)), hierarchy) +# sol = solve(prob, ode_alg, reltol=reltol, abstol=abstol, +# saveat=hierarchy.bg.x_grid, +# ) +u₀ + +# NN = get_nn(4,pertlen+2,32) +# nnin = hcat([u...,k,x]) + +# NN setup +m=16; +pertlen=2(ℓᵧ+1)+5 +NN₁ = SimpleChain(static(pertlen+2), + TurboDense{true}(tanh, m), + TurboDense{true}(tanh, m), + TurboDense{false}(identity, 2) #have not tested non-scalar output + ); +p1 = SimpleChains.init_params(NN₁;rng); +G1 = SimpleChains.alloc_threaded_grad(NN₁); + + + +# plot the solution to see if jagged +Plots.plot(bg.x_grid,sol(bg.x_grid)[end-3,:]) + + + +Plots.plot!(bg.x_grid,results(bg.x_grid)[end-3,:]) +Plots.plot!(bg.x_grid,solt(bg.x_grid)[end-3,:],ls=:dash) +Plots.plot(bg.x_grid,sol(bg.x_grid)[end-4,:]) +Plots.plot!(bg.x_grid,results(bg.x_grid)[end-4,:]) +Plots.plot!(bg.x_grid,solt(bg.x_grid)[end-4,:],ls=:dash) + + + +function hierarchy_nn_p!(du, u, p, x; hierarchy=hierarchy,NN=NN₁) + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + α_c = par.α_c + Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + Θ′, Θᵖ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 + + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] + # + Ω_c * a^(4+α_c) * σ_c + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[0] + ) + # matter + nnin = hcat([u...,k,x]) + println(size(nnin)) + u′ = NN(nnin,p) + println(size(u′)) + δ′ = u′[1] #k / ℋₓ * v - 3Φ′ + v′ = u′[2] + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) + # photons + Π = Θ[2] + Θᵖ[2] + Θᵖ[0] + Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + for ℓ in 2:(ℓᵧ-1) + Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + # polarized photons + Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + for ℓ in 1:(ℓᵧ-1) + Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + end + # photon boundary conditions: diffusion damping + Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + #END RSA + + du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + return nothing +end + + +function hierarchy_nn_p(u, p, x; hierarchy=hierarchy,NN=NN₁) + # compute cosmological quantities at time x, and do some unpacking + k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih + Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 + ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) + a = x2a(x) + R = 4Ω_r / (3Ω_b * a) + csb² = ih.csb²(x) + α_c = par.α_c + + Θ = [u[1],u[2],u[3],u[4]] + Θᵖ = [u[5],u[6],u[7],u[8]] + Φ, δ_c, v_c,δ_b, v_b = u[9:end] + # Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) + # Θ′, Θᵖ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 + + # metric perturbations (00 and ij FRW Einstein eqns) + Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[3] + # + Ω_c * a^(4+α_c) * σ_c + ) + + Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( + Ω_c * a^(2+α_c) * δ_c + + Ω_b * a^(-1) * δ_b + + 4Ω_r * a^(-2) * Θ[1] + ) + # matter + nnin = hcat([u...,k,x]) + u′ = NN(nnin,p) + # δ′, v′ = NN(nnin,p) + δ′ = u′[1] #k / ℋₓ * v - 3Φ′ + v′ = u′[2] + + δ_b′ = k / ℋₓ * v_b - 3Φ′ + v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[2] + v_b) + # photons + Π = Θ[3] + Θᵖ[3] + Θᵖ[1] + + # Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ + # Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) + Θ′0 = -k / ℋₓ * Θ[2] - Φ′ + Θ′1 = k / (3ℋₓ) * Θ[1] - 2k / (3ℋₓ) * Θ[3] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[2] + v_b/3) + Θ′2 = 2 * k / ((2*2+1) * ℋₓ) * Θ[3-1] - + (2+1) * k / ((2*2+1) * ℋₓ) * Θ[3+1] + τₓ′ * (Θ[3] - Π * δ_kron(2, 2) / 10) + # for ℓ in 2:(ℓᵧ-1 ) + # Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - + # (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + # end + # polarized photons + # Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) + Θᵖ′0 = -k / ℋₓ * Θᵖ[2] + τₓ′ * (Θᵖ[1] - Π / 2) + # for ℓ in 1:(ℓᵧ-1) + Θᵖ′1 = 1 * k / ((2*1+1) * ℋₓ) * Θᵖ[2-1] - + (1+1) * k / ((2*1+1) * ℋₓ) * Θᵖ[2+1] + τₓ′ * (Θᵖ[2] - Π * δ_kron(1, 2) / 10) + Θᵖ′2 = 2 * k / ((2*2+1) * ℋₓ) * Θᵖ[3-1] - + (2+1) * k / ((2*2+1) * ℋₓ) * Θᵖ[3+1] + τₓ′ * (Θᵖ[3] - Π * δ_kron(2, 2) / 10) + # end + # for ℓ in 1:(ℓᵧ-1) + # Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - + # (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) + # end + # photon boundary conditions: diffusion damping + # Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] + # Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] + Θ′3 = k / ℋₓ * Θ[4-1] - ( (3 + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[4] + Θᵖ′3 = k / ℋₓ * Θᵖ[4-1] - ( (3 + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[4] + # du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in + # println("type return: ", typeof([Θ′0, Θ′1, Θ′2, Θ′3, Θᵖ′0, Θᵖ′1, Θᵖ′2, Θᵖ′3 , Φ′, δ′, v′, δ_b′, v_b′ ]) ) + return [Θ′0, Θ′1, Θ′2, Θ′3, Θᵖ′0, Θᵖ′1, Θᵖ′2, Θᵖ′3 , Φ′, δ′, v′, δ_b′, v_b′ ] +end +#------------------------------------------- +# test the input +hierarchy_nn_p(u₀,p1,-20.0f0) +typeof(u₀) + +du_test_p = rand(pertlen+2); +hierarchy_nn_p!(du_test_p,u₀,p1,-20.0) + +du_test_p + +u₀ +prob = ODEProblem{false}(hierarchy_nn_p, u₀, (xᵢ , zero(Float32)), p1) +sol = solve(prob, KenCarp4(), reltol=reltol, abstol=abstol, + saveat=hierarchy.bg.x_grid, + ) +# attempt to solve + +#now try again with the sciml sensitivity + +# Write some loss functions (sciml g functions) in terms of P(k) for *single mode* +# But first do some basic tests +dg_dscr(out, u, p, t, i) = (out.=-1.0.+u) +ts = -20.0:1.0:0.0 + +# This is dLdP * dPdu bc dg is really dgdu +PL = Bolt.get_Pk # Make some kind of function like this extracting from existing Pk function by feeding in u. +Nmodes = # do the calculation for some fiducial box size (motivated by a survey) +σₙ = P ./ Nmodes +dPdL = -2.f0 * (data-P) ./ σₙ.^2 # basic diagonal loss +dPdu = # the thing to extract, some weighted ratio of the matter density perturbations, so only 3 elements of u +dg_dscr_P(out,u,p,t,i) =(out.= -1.0) + + + +g_cts(u,p,t) = (sum(u).^2) ./ 2.f0 +dg_cts(out, u, p, t) = (out .= 2.f0*sum(u)) + +# This seems like it takes forever even just to produce an error? Is it just the first time or is it because +# it has to do finite diff? +# maybe it is too many ts? Maybe the network just sucks so the solver takes a long time... + +# discrete +@time res = adjoint_sensitivities(sol,#results, + KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=false,autodiff=false);t=ts,dgdu_discrete=dg_dscr,abstol=abstol, + reltol=reltol); # works (or at least spits out something) +# after the first one, this time returns: +#326.239041 seconds (3.36 G allocations: 102.113 GiB, 7.42% gc time) - for full bg ts +#140.153200 seconds (1.46 G allocations: 44.512 GiB, 7.59% gc time) - for ts with spacing dx=1.0 + +# continuous +@time res = adjoint_sensitivities(sol,#results, + KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=true);dgdu_continuous=dg_cts,g=g_cts,abstol=abstol, + reltol=reltol); # +# after the first one, this time returns: +#111.447511 seconds (1.24 G allocations: 37.651 GiB, 7.26% gc time) + +# Alright great so this won't work until we can make Bolt take an array as input, even abstract/component array +# I bet Zack actually tried this previously and this is what he found? Can ask him... + +# Idk if we can do this with the code because of the interpolators... +# What does hierarchy actually hold that is the problem? + + + +@time res = adjoint_sensitivities(sol, + KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)); + dgdu_continuous=dg_cts,g=g_cts,abstol=abstol, + reltol=reltol); # + + +""" +it holds some ints for maximum boltzmann hierarchy, obviously this is not a problem for compnent arrays + +Also what are we even talking about? We just want to make it recognize the NN parameters as p right? + +So we should consider hierarchy etc. fixed, and it should be no problem to solve the backwards thing with that struct +We can just make it global? + +Didn't I already encounter this in May-ish? What did I do there? I think I redefined boltsolve +Clearly the boltsolve here is NOT RIGHT + +Ok but even the one I defined to extend Hierarchy to hold NN parameters won't work in this sense bc still holds interpolators + +Can we just hack the code to make Hierarchy struct in to an "AbstractArray"? +Ask Zack/Marius...wait for an answer + +Alright, so, what I'll do for now is just make this file hold a global set of cosmological parameters. +Then we can freeze a global version of the background and ionization history interpolators + Later, As Zack said, we can, without too much trouble, extend the code to just take a component array of Chebyshev weights, + and then construct our interpolators out of those every time instead of using the splines. + (Or we could even use the bsplines, we just save the bases and the weights, do what I did for the du/du' thing for every query of it) + We just lose the overall functionality and ease of use we've already relied on Julia packages for... +After that those no longer have to be passed to the hierarchy! function, the only thing that will have to be passed is NN weights +At that point this SciML machinery should just work... +So I basically just need to redefine hierarchy! to only take the nn parameters and maek sure it can access the global interpolators +(In principle we could just run HMC over cosmo params this way by managing the scope a little better, but would be gross and probably slow) + +This should not be hard -> 30 min exercise... + +Ok this kind of works. + +Now: +1. Update the cost functions to be something we actually care about (power spectrum in a single k mode, generalize to multiple later) +2. Try the autodiff vecs (which will probably break) -> it does +3. Rough performance numbers (profile) what is taking so long? -> performannce btwn 2 is similar, reducing dscrt pts by factor of 10 is only performance gain of 2 +4. Embed this in an optimization framework (make it a black box gradient function call) and test that gradient +5. Run that optimization framework (which may have to be on nersc, which is fine) +6. Is this actually faster than doing forwarddiff for the gradient? It should be...in principle +7. Can we *make* it faster than forwarddiff gradient? (Tolerance relaxation over training, fewer discrete points etc.) + +""" + +# Does pretraining with "good" parameters make this faster, since it is smooth? + +#check what the derivative function actually looks like for this solution +#^It looks right early on when dm doesn't matter so much, then later on in MD looks super wrong +#^Basically what we expect + +# Try pre-training on \LambdaCDM +# 1. Generate f(u,t) from true solution u +# 2. Train NN on f(u,t) (w/ e.g. SC train_unbatched) +# 3. Now evaluate the derivative with adjoints and compare the times to before +# (if it is much faster this is an argument to enforce smoothness somehow) +#^or maybe we just initialize the weights to be numerically small, but not zero? +aa = zeros(Float64,pertlen) +hierarchy_nn_p!(aa,results(ts[1]),p1,ts[1]) +aa +# train SC network for many steps +fu_test = zeros(Float32,length(bg.x_grid),pertlen); +for i in 1:length(bg.x_grid) + fu_test[i,:] .= hierarchy_nn_p(results(bg.x_grid[i]),p1,bg.x_grid[i]) + # println(fu_test[i,:]) +end + +fu_test_dc_vc = hcat(fu_test[:,end-3],fu_test[:,end-2]); +fu_test_dc_vc + + +scloss = SimpleChains.add_loss(NN₁, SquaredLoss(fu_test_dc_vc)) +λ=1e-1 +scloss_reg = FrontLastPenalty(scloss, + L2Penalty(λ), L2Penalty(λ) ) +N_ADAM = 100_000 +typeof(results(ts)) +# X_train = hcat(Array(results(ts))', ts, k.*ones(length(ts))) +X_train = hcat(Array(results(bg.x_grid))', k.*ones(length(bg.x_grid)),bg.x_grid) +size(X_train') + +NN₁(X_train[1,:],p1) +NN₁(X_train',p1) +scloss(X_train[1,:],p1) +scloss(X_train',p1) +scloss(X_train,p1) +X_train' +SimpleChains.train_unbatched!( + G1, p1, scloss, + X_train', SimpleChains.ADAM(), N_ADAM + ); +p1 +scloss(X_train',p1) + + + +solt = solve(remake(prob), KenCarp4(), u0 = u₀, tspan = (-20.0, 0.0), p = p1 + ); diff --git a/src/background.jl b/src/background.jl index 3478faf..78dc206 100644 --- a/src/background.jl +++ b/src/background.jl @@ -3,8 +3,8 @@ const ζ = 1.2020569 #Riemann ζ(3) for phase space integrals -H₀(par::AbstractCosmoParams) = par.h * km_s_Mpc_100 -ρ_crit(par::AbstractCosmoParams) = (3 / 8π) * H₀(par)^2 / G_natural +H₀(par::AbstractCosmoParams{T}) = par.h * T(km_s_Mpc_100) +ρ_crit(par::AbstractCosmoParams{T}) = T( (3 / 8π) * H₀(par)^2 / G_natural ) function Ω_Λ(par::AbstractCosmoParams) #Below can definitely be more streamlined, I am just making it work for now return 1 - (par.Ω_r + par.Ω_b + par.Ω_c) @@ -25,9 +25,9 @@ H(x, par::AbstractCosmoParams) = H_a(x2a(x),par) ℋ(x, par::AbstractCosmoParams) = ℋ_a(x2a(x), par) # conformal time -function η(x, par::AbstractCosmoParams,quad_pts,quad_wts) - logamin,logamax=-13.75,log10(x2a(x)) - Iη(y) = 1.0 / (xq2q(y,logamin,logamax) * ℋ_a(xq2q(y,logamin,logamax), par))/ dxdq(xq2q(y,logamin,logamax),logamin,logamax) +function η(x, par::AbstractCosmoParams{T},quad_pts,quad_wts) where T + logamin,logamax=-T(13.75),log10(x2a(x)) + Iη(y) = 1 / (T(xq2q(y,logamin,logamax)) * ℋ_a(T(xq2q(y,logamin,logamax)), par))/ T(dxdq(xq2q(y,logamin,logamax),logamin,logamax)) return sum(Iη.(quad_pts).*quad_wts) end @@ -60,7 +60,7 @@ function Background(par::AbstractCosmoParams{T}; x_grid=-20.0:0.01:0.0, nq=15) w η_ = spline([η(x, par,quad_pts,quad_wts) for x in x_grid], x_grid) return Background( T(H₀(par)), - T(η(0.0, par,quad_pts,quad_wts)), + T(η(zero(T), par,quad_pts,quad_wts)), T(ρ_crit(par)), T(Ω_Λ(par)), From cc70300b596eba23ee6e30ba43eb99c9198cece1 Mon Sep 17 00:00:00 2001 From: Jamie Sullivan Date: Fri, 15 Sep 2023 16:29:48 -0700 Subject: [PATCH 10/10] update --- scripts/adjoint_boltsolve.jl | 235 ++--------------------------------- 1 file changed, 8 insertions(+), 227 deletions(-) diff --git a/scripts/adjoint_boltsolve.jl b/scripts/adjoint_boltsolve.jl index bb46267..2e247db 100644 --- a/scripts/adjoint_boltsolve.jl +++ b/scripts/adjoint_boltsolve.jl @@ -26,35 +26,15 @@ k = 𝕡.h*kMpch #get k in our units typeof(8π) -# hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], p1,15); - -# hierarchytestspl = Hierarchy_spl(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], -# Bolt.spline(ones(length(bgtest.x_grid)),bgtest.x_grid), -# Bolt.spline(ones(length(bgtest.x_grid)),bgtest.x_grid), -# 15); - -# u0test = Bolt.initial_conditions_nn(-20.0,hierarchytestnn); - -# hierarchy_nn!(zeros(pertlen+2), u0test, hierarchytestnn, -20.0 ) - hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ); results = boltsolve(hierarchy; reltol=reltol, abstol=abstol); res = hcat(results.(bg.x_grid)...) -# δ_c,v_c = res[end-3,:],res[end-2,:] # Deconstruct the boltsolve... #------------------------------------------- xᵢ = first(Float32.(hierarchy.bg.x_grid)) u₀ = Float32.(initial_conditions(xᵢ, hierarchy)); -# prob = ODEProblem{true}(hierarchy_nn!, u₀, (xᵢ , zero(T)), hierarchy) -# sol = solve(prob, ode_alg, reltol=reltol, abstol=abstol, -# saveat=hierarchy.bg.x_grid, -# ) -u₀ - -# NN = get_nn(4,pertlen+2,32) -# nnin = hcat([u...,k,x]) # NN setup m=16; @@ -68,76 +48,6 @@ p1 = SimpleChains.init_params(NN₁;rng); G1 = SimpleChains.alloc_threaded_grad(NN₁); - -# plot the solution to see if jagged -Plots.plot(bg.x_grid,sol(bg.x_grid)[end-3,:]) - - - -Plots.plot!(bg.x_grid,results(bg.x_grid)[end-3,:]) -Plots.plot!(bg.x_grid,solt(bg.x_grid)[end-3,:],ls=:dash) -Plots.plot(bg.x_grid,sol(bg.x_grid)[end-4,:]) -Plots.plot!(bg.x_grid,results(bg.x_grid)[end-4,:]) -Plots.plot!(bg.x_grid,solt(bg.x_grid)[end-4,:],ls=:dash) - - - -function hierarchy_nn_p!(du, u, p, x; hierarchy=hierarchy,NN=NN₁) - # compute cosmological quantities at time x, and do some unpacking - k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih - Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2 - ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x) - a = x2a(x) - R = 4Ω_r / (3Ω_b * a) - csb² = ih.csb²(x) - α_c = par.α_c - Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack) - Θ′, Θᵖ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6 - - # metric perturbations (00 and ij FRW Einstein eqns) - Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2] - # + Ω_c * a^(4+α_c) * σ_c - ) - - Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * ( - Ω_c * a^(2+α_c) * δ_c - + Ω_b * a^(-1) * δ_b - + 4Ω_r * a^(-2) * Θ[0] - ) - # matter - nnin = hcat([u...,k,x]) - println(size(nnin)) - u′ = NN(nnin,p) - println(size(u′)) - δ′ = u′[1] #k / ℋₓ * v - 3Φ′ - v′ = u′[2] - - δ_b′ = k / ℋₓ * v_b - 3Φ′ - v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b) - # photons - Π = Θ[2] + Θᵖ[2] + Θᵖ[0] - Θ′[0] = -k / ℋₓ * Θ[1] - Φ′ - Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3) - for ℓ in 2:(ℓᵧ-1) - Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] - - (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10) - end - # polarized photons - Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2) - for ℓ in 1:(ℓᵧ-1) - Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] - - (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10) - end - # photon boundary conditions: diffusion damping - Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ] - Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ] - #END RSA - - du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in - return nothing -end - - function hierarchy_nn_p(u, p, x; hierarchy=hierarchy,NN=NN₁) # compute cosmological quantities at time x, and do some unpacking k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih @@ -223,153 +133,24 @@ prob = ODEProblem{false}(hierarchy_nn_p, u₀, (xᵢ , zero(Float32)), p1) sol = solve(prob, KenCarp4(), reltol=reltol, abstol=abstol, saveat=hierarchy.bg.x_grid, ) -# attempt to solve -#now try again with the sciml sensitivity - -# Write some loss functions (sciml g functions) in terms of P(k) for *single mode* -# But first do some basic tests dg_dscr(out, u, p, t, i) = (out.=-1.0.+u) ts = -20.0:1.0:0.0 -# This is dLdP * dPdu bc dg is really dgdu -PL = Bolt.get_Pk # Make some kind of function like this extracting from existing Pk function by feeding in u. -Nmodes = # do the calculation for some fiducial box size (motivated by a survey) -σₙ = P ./ Nmodes -dPdL = -2.f0 * (data-P) ./ σₙ.^2 # basic diagonal loss -dPdu = # the thing to extract, some weighted ratio of the matter density perturbations, so only 3 elements of u -dg_dscr_P(out,u,p,t,i) =(out.= -1.0) - - +# # This is dLdP * dPdu bc dg is really dgdu +# PL = Bolt.get_Pk # Make some kind of function like this extracting from existing Pk function by feeding in u. +# Nmodes = # do the calculation for some fiducial box size (motivated by a survey) +# σₙ = P ./ Nmodes +# dPdL = -2.f0 * (data-P) ./ σₙ.^2 # basic diagonal loss +# dPdu = # the thing to extract, some weighted ratio of the matter density perturbations, so only 3 elements of u +# dg_dscr_P(out,u,p,t,i) =(out.= -1.0) g_cts(u,p,t) = (sum(u).^2) ./ 2.f0 dg_cts(out, u, p, t) = (out .= 2.f0*sum(u)) -# This seems like it takes forever even just to produce an error? Is it just the first time or is it because -# it has to do finite diff? -# maybe it is too many ts? Maybe the network just sucks so the solver takes a long time... - -# discrete -@time res = adjoint_sensitivities(sol,#results, - KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=false,autodiff=false);t=ts,dgdu_discrete=dg_dscr,abstol=abstol, - reltol=reltol); # works (or at least spits out something) -# after the first one, this time returns: -#326.239041 seconds (3.36 G allocations: 102.113 GiB, 7.42% gc time) - for full bg ts -#140.153200 seconds (1.46 G allocations: 44.512 GiB, 7.59% gc time) - for ts with spacing dx=1.0 - -# continuous -@time res = adjoint_sensitivities(sol,#results, - KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=true);dgdu_continuous=dg_cts,g=g_cts,abstol=abstol, - reltol=reltol); # -# after the first one, this time returns: -#111.447511 seconds (1.24 G allocations: 37.651 GiB, 7.26% gc time) - -# Alright great so this won't work until we can make Bolt take an array as input, even abstract/component array -# I bet Zack actually tried this previously and this is what he found? Can ask him... - -# Idk if we can do this with the code because of the interpolators... -# What does hierarchy actually hold that is the problem? - - +# @time res = adjoint_sensitivities(sol, KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)); dgdu_continuous=dg_cts,g=g_cts,abstol=abstol, reltol=reltol); # - - -""" -it holds some ints for maximum boltzmann hierarchy, obviously this is not a problem for compnent arrays - -Also what are we even talking about? We just want to make it recognize the NN parameters as p right? - -So we should consider hierarchy etc. fixed, and it should be no problem to solve the backwards thing with that struct -We can just make it global? - -Didn't I already encounter this in May-ish? What did I do there? I think I redefined boltsolve -Clearly the boltsolve here is NOT RIGHT - -Ok but even the one I defined to extend Hierarchy to hold NN parameters won't work in this sense bc still holds interpolators - -Can we just hack the code to make Hierarchy struct in to an "AbstractArray"? -Ask Zack/Marius...wait for an answer - -Alright, so, what I'll do for now is just make this file hold a global set of cosmological parameters. -Then we can freeze a global version of the background and ionization history interpolators - Later, As Zack said, we can, without too much trouble, extend the code to just take a component array of Chebyshev weights, - and then construct our interpolators out of those every time instead of using the splines. - (Or we could even use the bsplines, we just save the bases and the weights, do what I did for the du/du' thing for every query of it) - We just lose the overall functionality and ease of use we've already relied on Julia packages for... -After that those no longer have to be passed to the hierarchy! function, the only thing that will have to be passed is NN weights -At that point this SciML machinery should just work... -So I basically just need to redefine hierarchy! to only take the nn parameters and maek sure it can access the global interpolators -(In principle we could just run HMC over cosmo params this way by managing the scope a little better, but would be gross and probably slow) - -This should not be hard -> 30 min exercise... - -Ok this kind of works. - -Now: -1. Update the cost functions to be something we actually care about (power spectrum in a single k mode, generalize to multiple later) -2. Try the autodiff vecs (which will probably break) -> it does -3. Rough performance numbers (profile) what is taking so long? -> performannce btwn 2 is similar, reducing dscrt pts by factor of 10 is only performance gain of 2 -4. Embed this in an optimization framework (make it a black box gradient function call) and test that gradient -5. Run that optimization framework (which may have to be on nersc, which is fine) -6. Is this actually faster than doing forwarddiff for the gradient? It should be...in principle -7. Can we *make* it faster than forwarddiff gradient? (Tolerance relaxation over training, fewer discrete points etc.) - -""" - -# Does pretraining with "good" parameters make this faster, since it is smooth? - -#check what the derivative function actually looks like for this solution -#^It looks right early on when dm doesn't matter so much, then later on in MD looks super wrong -#^Basically what we expect - -# Try pre-training on \LambdaCDM -# 1. Generate f(u,t) from true solution u -# 2. Train NN on f(u,t) (w/ e.g. SC train_unbatched) -# 3. Now evaluate the derivative with adjoints and compare the times to before -# (if it is much faster this is an argument to enforce smoothness somehow) -#^or maybe we just initialize the weights to be numerically small, but not zero? -aa = zeros(Float64,pertlen) -hierarchy_nn_p!(aa,results(ts[1]),p1,ts[1]) -aa -# train SC network for many steps -fu_test = zeros(Float32,length(bg.x_grid),pertlen); -for i in 1:length(bg.x_grid) - fu_test[i,:] .= hierarchy_nn_p(results(bg.x_grid[i]),p1,bg.x_grid[i]) - # println(fu_test[i,:]) -end - -fu_test_dc_vc = hcat(fu_test[:,end-3],fu_test[:,end-2]); -fu_test_dc_vc - - -scloss = SimpleChains.add_loss(NN₁, SquaredLoss(fu_test_dc_vc)) -λ=1e-1 -scloss_reg = FrontLastPenalty(scloss, - L2Penalty(λ), L2Penalty(λ) ) -N_ADAM = 100_000 -typeof(results(ts)) -# X_train = hcat(Array(results(ts))', ts, k.*ones(length(ts))) -X_train = hcat(Array(results(bg.x_grid))', k.*ones(length(bg.x_grid)),bg.x_grid) -size(X_train') - -NN₁(X_train[1,:],p1) -NN₁(X_train',p1) -scloss(X_train[1,:],p1) -scloss(X_train',p1) -scloss(X_train,p1) -X_train' -SimpleChains.train_unbatched!( - G1, p1, scloss, - X_train', SimpleChains.ADAM(), N_ADAM - ); -p1 -scloss(X_train',p1) - - - -solt = solve(remake(prob), KenCarp4(), u0 = u₀, tspan = (-20.0, 0.0), p = p1 - );