Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

early nn stuff #89

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,64 @@ authors = ["Zack Li", "Jamie Sullivan"]
version = "0.1.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypergeometricFunctions = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NaturallyUnitful = "872cf16e-200e-11e9-2cdf-8bb39cfbec41"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
NumericalIntegration = "e7bfaba1-d571-5449-8927-abc22e82249b"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f"
UnitfulCosmo = "961331e1-62bb-46d9-b9e3-f058129f1391"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
156 changes: 156 additions & 0 deletions scripts/adjoint_boltsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Basic attempt to use sciml
# Script preamble: set up a single-k Boltzmann solve with Bolt, then build the
# small SimpleChains network that will replace the dark-matter equations.
using OrdinaryDiffEq, SciMLSensitivity
using SimpleChains
using Random
rng = Xoshiro(123);  # fixed seed so NN initialization is reproducible
using Plots
# Bolt boilerplate
using Bolt
kMpch=0.01;  # wavenumber in h/Mpc
ℓᵧ=3;        # photon hierarchy truncation multipole
reltol=1e-5;
abstol=1e-5;
# Parameters are carried in log space to enforce sign constraints:
# p_dm = [log(Ω_c), log(-α_c)]
p_dm=[log(0.3),log(- -3.0)]
Ω_c,α_c = exp(p_dm[1]),-exp(p_dm[2]) #log pos params
#standard code
𝕡 = CosmoParams{Float32}(Ω_c=Ω_c,α_c=α_c)
𝕡

bg = Background(𝕡; x_grid=-20.0f0:0.1f0:0.0f0)

# Sanity check that the conformal-time helper propagates Float32.
typeof(Bolt.η(-20.f0,𝕡,zeros(Float32,5),zeros(Float32,5)))


ih= Bolt.get_saha_ih(𝕡, bg);  # Saha-only ionization history (cheaper than RECFAST)
k = 𝕡.h*kMpch #get k in our units

typeof(8π)

# Reference solve with the standard (non-NN) hierarchy for later comparison.
hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ);
results = boltsolve(hierarchy; reltol=reltol, abstol=abstol);
res = hcat(results.(bg.x_grid)...)


# Deconstruct the boltsolve...
#-------------------------------------------
xᵢ = first(Float32.(hierarchy.bg.x_grid))
u₀ = Float32.(initial_conditions(xᵢ, hierarchy));

# NN setup: maps the full state plus (k, x) to the two dark-matter derivatives.
m=16;  # hidden-layer width
pertlen=2(ℓᵧ+1)+5  # state length: photon T+P multipoles plus (Φ, δ_c, v_c, δ_b, v_b)
NN₁ = SimpleChain(static(pertlen+2),
TurboDense{true}(tanh, m),
TurboDense{true}(tanh, m),
TurboDense{false}(identity, 2) #have not tested non-scalar output
);
p1 = SimpleChains.init_params(NN₁;rng);
G1 = SimpleChains.alloc_threaded_grad(NN₁);


"""
    hierarchy_nn_p(u, p, x; hierarchy=hierarchy, NN=NN₁)

Out-of-place ODE right-hand side for the photon Boltzmann hierarchy (ℓᵧ=3,
temperature and polarization) in which the dark-matter perturbation
derivatives (δ′, v′) are produced by the neural network `NN` with weights `p`,
evaluated on the full state plus `(k, x)`. Returns the 13-element vector
`[Θ′0, Θ′1, Θ′2, Θ′3, Θᵖ′0, Θᵖ′1, Θᵖ′2, Θᵖ′3, Φ′, δ′, v′, δ_b′, v_b′]`.

NOTE(review): multipoles are stored 1-based here (`Θ[1]` is the monopole),
while the commented reference equations use 0-based `Θ[0..3]`; all hard-coded
indices below are shifted by one relative to those comments.
"""
function hierarchy_nn_p(u, p, x; hierarchy=hierarchy,NN=NN₁)
# compute cosmological quantities at time x, and do some unpacking
k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih
Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2
ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x)
a = x2a(x)
R = 4Ω_r / (3Ω_b * a)
csb² = ih.csb²(x)
α_c = par.α_c

# manual unpack of the state (1-based): temperature and polarization multipoles,
# then metric/matter/baryon perturbations
Θ = [u[1],u[2],u[3],u[4]]
Θᵖ = [u[5],u[6],u[7],u[8]]
Φ, δ_c, v_c,δ_b, v_b = u[9:end]
# Θ, Θᵖ, Φ, δ_c, v_c,δ_b, v_b = unpack(u, hierarchy) # the Θ, Θᵖ, 𝒩 are views (see unpack)
# Θ′, Θᵖ′, _, _, _, _, _ = unpack(du, hierarchy) # will be sweetened by .. syntax in 1.6

# metric perturbations (00 and ij FRW Einstein eqns)
Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[3]
# + Ω_c * a^(4+α_c) * σ_c  # dark-matter anisotropic stress term, currently dropped
)

Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * (
Ω_c * a^(2+α_c) * δ_c
+ Ω_b * a^(-1) * δ_b
+ 4Ω_r * a^(-2) * Θ[1]
)
# matter: the NN replaces the analytic (δ′, v′) equations; input is [u..., k, x]
nnin = hcat([u...,k,x])
u′ = NN(nnin,p)
# δ′, v′ = NN(nnin,p)
δ′ = u′[1] #k / ℋₓ * v - 3Φ′  # (analytic form the NN is meant to learn)
v′ = u′[2]

# baryons: standard continuity and Euler equations with Thomson drag
δ_b′ = k / ℋₓ * v_b - 3Φ′
v_b′ = -v_b - k / ℋₓ * ( Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[2] + v_b)
# photons
Π = Θ[3] + Θᵖ[3] + Θᵖ[1]

# Unrolled hierarchy (ℓ = 0..2); the commented lines are the 0-based loop form.
# Θ′[0] = -k / ℋₓ * Θ[1] - Φ′
# Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3)
Θ′0 = -k / ℋₓ * Θ[2] - Φ′
Θ′1 = k / (3ℋₓ) * Θ[1] - 2k / (3ℋₓ) * Θ[3] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[2] + v_b/3)
Θ′2 = 2 * k / ((2*2+1) * ℋₓ) * Θ[3-1] -
(2+1) * k / ((2*2+1) * ℋₓ) * Θ[3+1] + τₓ′ * (Θ[3] - Π * δ_kron(2, 2) / 10)
# for ℓ in 2:(ℓᵧ-1 )
#     Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] -
#         (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * δ_kron(ℓ, 2) / 10)
# end
# polarized photons (same unrolling, ℓ = 0..2)
# Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2)
Θᵖ′0 = -k / ℋₓ * Θᵖ[2] + τₓ′ * (Θᵖ[1] - Π / 2)
# for ℓ in 1:(ℓᵧ-1)
Θᵖ′1 = 1 * k / ((2*1+1) * ℋₓ) * Θᵖ[2-1] -
(1+1) * k / ((2*1+1) * ℋₓ) * Θᵖ[2+1] + τₓ′ * (Θᵖ[2] - Π * δ_kron(1, 2) / 10)
Θᵖ′2 = 2 * k / ((2*2+1) * ℋₓ) * Θᵖ[3-1] -
(2+1) * k / ((2*2+1) * ℋₓ) * Θᵖ[3+1] + τₓ′ * (Θᵖ[3] - Π * δ_kron(2, 2) / 10)
# end
# for ℓ in 1:(ℓᵧ-1)
#     Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] -
#         (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * δ_kron(ℓ, 2) / 10)
# end
# photon boundary conditions: diffusion damping
# Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[ℓᵧ]
# Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ( (ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[ℓᵧ]
Θ′3 = k / ℋₓ * Θ[4-1] - ( (3 + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θ[4]
Θᵖ′3 = k / ℋₓ * Θᵖ[4-1] - ( (3 + 1) / (ℋₓ * ηₓ) - τₓ′ ) * Θᵖ[4]
# du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′ # put non-photon perturbations back in
# println("type return: ", typeof([Θ′0, Θ′1, Θ′2, Θ′3, Θᵖ′0, Θᵖ′1, Θᵖ′2, Θᵖ′3 , Φ′, δ′, v′, δ_b′, v_b′ ]) )
return [Θ′0, Θ′1, Θ′2, Θ′3, Θᵖ′0, Θᵖ′1, Θᵖ′2, Θᵖ′3 , Φ′, δ′, v′, δ_b′, v_b′ ]
end
#-------------------------------------------
# Quick smoke tests of the RHS before handing it to the ODE solver.
hierarchy_nn_p(u₀,p1,-20.0f0)
typeof(u₀)

# NOTE(review): the original called `hierarchy_nn_p!(du_test_p, u₀, p1, -20.0)`,
# but no in-place `hierarchy_nn_p!` is defined in this file (only the
# out-of-place version above), and the buffer was sized `pertlen+2` (the NN
# *input* width) rather than `pertlen` (the state width). Use the out-of-place
# RHS — which is also what the ODEProblem below uses.
du_test_p = hierarchy_nn_p(u₀, p1, -20.0f0)

du_test_p

u₀
# Out-of-place (ODEProblem{false}) integration from xᵢ to 0 with the NN
# weights p1 as the ODE parameters, saving on the background grid.
prob = ODEProblem{false}(hierarchy_nn_p, u₀, (xᵢ , zero(Float32)), p1)
sol = solve(prob, KenCarp4(), reltol=reltol, abstol=abstol,
saveat=hierarchy.bg.x_grid,
)

# Discrete adjoint cost pieces: dg/duᵢ = uᵢ - 1, consistent with a saved-point
# cost g = Σᵢ (uᵢ - 1)²/2 (the g itself is not spelled out in this script).
dg_dscr(out, u, p, t, i) = (out.=-1.0.+u)
ts = -20.0:1.0:0.0

# # This is dLdP * dPdu bc dg is really dgdu
# PL = Bolt.get_Pk # Make some kind of function like this extracting from existing Pk function by feeding in u.
# Nmodes = # do the calculation for some fiducial box size (motivated by a survey)
# σₙ = P ./ Nmodes
# dPdL = -2.f0 * (data-P) ./ σₙ.^2 # basic diagonal loss
# dPdu = # the thing to extract, some weighted ratio of the matter density perturbations, so only 3 elements of u
# dg_dscr_P(out,u,p,t,i) =(out.= -1.0)

# Continuous cost g(u) = (Σᵢ uᵢ)² / 2, so ∂g/∂uᵢ = Σⱼ uⱼ for every component.
g_cts(u,p,t) = (sum(u).^2) ./ 2.f0
# BUGFIX: the original wrote `out .= 2.f0*sum(u)`, the gradient of (Σu)² —
# off by a factor of two relative to g_cts above.
dg_cts(out, u, p, t) = (out .= sum(u))

# Continuous-adjoint sensitivities of the scalar cost g_cts with respect to
# the NN weights p1, using an interpolating adjoint with compiled ReverseDiff
# vector-Jacobian products.
@time res = adjoint_sensitivities(sol,
KenCarp4(),sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true));
dgdu_continuous=dg_cts,g=g_cts,abstol=abstol,
reltol=reltol); #
2 changes: 1 addition & 1 deletion scripts/first_plin.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Revise
# using Revise
using Bolt
using ForwardDiff
using Plots
Expand Down
138 changes: 138 additions & 0 deletions scripts/hmc_bg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
using SimpleChains
using Plots,CairoMakie,PairPlots
using LinearAlgebra
using AdvancedHMC, ForwardDiff
using OrdinaryDiffEq
using LogDensityProblems
using BenchmarkTools
using DelimitedFiles
using Plots.PlotMeasures
using DataInterpolations
using Random,Distributions
rng = Xoshiro(123)

# function to get a simplechain (fixed network architecture)
"""
    train_initial_network(Xtrain, Ytrain, rng; λ=0.0f0, m=128, use_l_bias=true,
                          use_L_bias=false, N_ADAM=1_000, N_epochs=3, N_rounds=4)

Build and train a fixed-architecture SimpleChains MLP (two hidden `tanh`
layers of width `m`) on `(Xtrain, Ytrain)` with unbatched ADAM. When `λ > 0`
an L2 penalty of strength `λ` is applied to the first and last layers.
Returns `(network_without_loss, trained_parameters)`.
"""
function train_initial_network(Xtrain,Ytrain,rng; λ=0.0f0, m=128,use_l_bias=true,use_L_bias=false,N_ADAM=1_000,N_epochs = 3,N_rounds=4)
    # Input/output widths come from the leading dimension of the data arrays.
    d_in,d_out = size(Xtrain)[1],size(Ytrain)[1]
    mlpd = SimpleChain(
        static(d_in),
        TurboDense{use_l_bias}(tanh, m),
        TurboDense{use_l_bias}(tanh, m),
        TurboDense{use_L_bias}(identity, d_out) # have not tested non-scalar output
    )
    p = SimpleChains.init_params(mlpd;rng);
    G = SimpleChains.alloc_threaded_grad(mlpd);
    # Only build the loss variant that is actually used for training; the
    # original constructed both the plain and regularized losses regardless
    # of λ, allocating an unused FrontLastPenalty wrapper.
    base_loss = SimpleChains.add_loss(mlpd, SquaredLoss(Ytrain))
    loss = λ > 0.0f0 ? FrontLastPenalty(base_loss, L2Penalty(λ), L2Penalty(λ)) : base_loss
    # N_rounds × N_epochs chunks of N_ADAM ADAM steps each; kept split so
    # per-round diagnostics can be inserted later (see FIXME in the original:
    # collapse if intermediate inspection is not needed).
    for _ in 1:N_rounds, _ in 1:N_epochs
        SimpleChains.train_unbatched!(
            G, p, loss, Xtrain, SimpleChains.ADAM(), N_ADAM
        );
    end
    mlpd_noloss = SimpleChains.remove_loss(mlpd)
    return mlpd_noloss,p
end


"""
    run_hmc(ℓπ, initial_θ; n_samples=2_000, n_adapts=1_000, backend=ForwardDiff)

Sample from the log-density `ℓπ` with NUTS (AdvancedHMC), jointly adapting the
diagonal mass matrix and step size for `n_adapts` iterations. Returns
`(samples_matrix, stats)` where `samples_matrix` is `n_samples × dim`.
"""
function run_hmc(ℓπ,initial_θ;n_samples=2_000,n_adapts=1_000,backend=ForwardDiff)
    # BUGFIX: the original used `D = size(initial_θ)`, a Tuple; the metric
    # constructor wants the integer dimension.
    D = length(initial_θ)
    metric = DiagEuclideanMetric(D)#ones(Float64,D).*0.01)
    hamiltonian = Hamiltonian(metric, ℓπ, backend)
    initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
    integrator = Leapfrog(initial_ϵ)
    proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
    adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
    samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true);
    # Stack the vector-of-vectors into an (n_samples × dim) matrix once; the
    # original recomputed the hcat in the return statement.
    h_samples = hcat(samples...)'
    return h_samples, stats
end

using Bolt
# first try just sampling with ODE over the 2 variables of interest with HMC and true ODE
"""
    dm_model(p_dm; kMpch=0.01, ℓᵧ=15, reltol=1e-5, abstol=1e-5)

Forward model for HMC: given log-space parameters
`p_dm = [log(Ω_c), log(-α_c)]` (log transform enforces Ω_c > 0 and α_c < 0),
solve the Boltzmann hierarchy at comoving wavenumber `kMpch` (in h/Mpc) and
return the dark-matter density perturbation δ_c on the background x-grid.
The element type `DT` is generic so ForwardDiff duals can flow through.
"""
function dm_model(p_dm::Array{DT,1};kMpch=0.01,ℓᵧ=15,reltol=1e-5,abstol=1e-5) where DT #k in h/Mpc
Ω_c,α_c = exp(p_dm[1]),-exp(p_dm[2]) #log pos params
# println("Ω_c,α_c = $(Ω_c) $(α_c)")
#standard code
𝕡 = CosmoParams{DT}(Ω_c=Ω_c,α_c=α_c)
bg = Background(𝕡; x_grid=-20.0:0.1:0.0)
# 𝕣 = Bolt.RECFAST(bg=bg, Yp=𝕡.Y_p, OmegaB=𝕡.Ω_b, OmegaG=𝕡.Ω_r)
# ih = IonizationHistory(𝕣, 𝕡, bg)
ih= Bolt.get_saha_ih(𝕡, bg);  # Saha-only ionization history (RECFAST path commented out)

k = 𝕡.h*kMpch #get k in our units
hierarchy = Hierarchy(BasicNewtonian(), 𝕡, bg, ih, k, ℓᵧ)
results = boltsolve(hierarchy; reltol=reltol, abstol=abstol)
res = hcat(results.(bg.x_grid)...)
# last four state rows hold (δ_c, v_c, δ_b, v_b); only δ_c is returned for now
δ_c,v_c = res[end-3,:],res[end-2,:]
return δ_c#,v_c
end

# Warm-up call at the fiducial parameters (Ω_c = 0.3, α_c = -3).
dm_model([log(0.3),log(- -3.0)])

# δ_true,v_true = dm_model([log(0.3),log(- -3.0)]);
δ_true= dm_model([log(0.3),log(- -3.0)]);  # noiseless truth used to build mock data
δ_true

#test jacobian
jdmtest = ForwardDiff.jacobian(dm_model,[log(0.3),log(- -3.0)])
jdmtest

#HMC for 2 ode params
# this is obviously very artificial, multiplicative noise
σfakeode = 0.1  # fractional (multiplicative) noise level
noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode
# noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode
# Ytrain_ode = [δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v]
Ytrain_ode = δ_true .+ noise_fakeode  # mock observed data = truth + one noise draw
noise_fakeode
Plots.plot(δ_true)
Plots.scatter!(Ytrain_ode)

# density functions for HMC
# Log-target over the 2 ODE parameters; data and noise model are closed over
# from the globals defined above (Ytrain_ode, δ_true, σfakeode).
struct LogTargetDensity_ode
    dim::Int
end

# Gaussian log-likelihood for the multiplicative noise model: the per-point
# standard deviation is σfakeode .* δ_true (the noise *scale*).
# BUGFIX: the original divided the residual by the particular noise
# realization `noise_fakeode`, which has random sign and can be arbitrarily
# close to zero, giving random per-point weights instead of a likelihood.
LogDensityProblems.logdensity(p::LogTargetDensity_ode, θ) = -sum(abs2, (Ytrain_ode-dm_model(θ))./(σfakeode .* δ_true)) / 2 # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity_ode) = p.dim
LogDensityProblems.capabilities(::Type{LogTargetDensity_ode}) = LogDensityProblems.LogDensityOrder{0}()
ℓπode = LogTargetDensity_ode(2)

# Short chain started at the truth as a smoke test (full runs used the
# defaults 2_000/1_000).
initial_ode = [log(0.3),log(- -3.0)];
ode_samples, ode_stats = run_hmc(ℓπode,initial_ode;n_samples=200,n_adapts=100)

ode_samples

ode_stats
ode_labels = Dict(
    :α => "parameter 1",
    :β => "parameter 2"
)

# Annoyingly PairPlots provides very nice functionality only if you use DataFrames (or Tables)
α,β = ode_samples[:,1], ode_samples[:,2]
using DataFrames
df = DataFrame(;α,β)

pairplot(df ,PairPlots.Truth(
            (;α =log(0.3),β=log(- -3.0)),
            label="Truth" ) )

# Evaluate the target away from the truth as a sanity check.
LogDensityProblems.logdensity(ℓπode,[log(0.3),log(- -5.0)])

# BUGFIX: the original had a bare `ForwardDiff.jacobian()` call here with no
# arguments, which throws a MethodError; removed.

@btime LogDensityProblems.logdensity(ℓπode,initial_ode)
# 14.466 ms (7036 allocations: 1.73 MiB)
# NOTE(review): @profview is provided by the VS Code Julia extension /
# ProfileView, neither of which is imported above — confirm the environment
# before running this line.
@profview LogDensityProblems.logdensity(ℓπode,initial_ode)
# so this is just slow, maybe not surprising

# open("./test/ode_deltac_hmc_trueinit_samplesshort.dat", "w") do io
#     writedlm(io, ode_samples)
# end
Loading