Skip to content

Commit

Permalink
rename _dist postfixes to _prior when used as a prior
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Feb 23, 2024
1 parent 39fd99a commit 4ea4c3c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
function default_rw_priors()
return (
var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_rw_value_dist = Normal()
var_RW_prior = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_rw_value_prior = Normal()
)
end

@model function random_walk(n; var_RW_dist, init_rw_value_dist)
@model function random_walk(n; var_RW_prior, init_rw_value_prior)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ var_RW_dist
init ~ init_rw_value_dist
σ²_RW ~ var_RW_prior
init ~ init_rw_value_prior
σ_RW = sqrt(σ²_RW)
rw = Vector{eltype(ϵ_t)}(undef, n)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Plots.PlotMeasures
using EpiAware
Random.seed!(0)
n = 30
latent_process_priors = (var_RW_dist = truncated(Normal(0.0, 0.5), 0.0, Inf),)
latent_process_priors = (var_RW_prior = truncated(Normal(0.0, 0.5), 0.0, Inf),)

model = random_walk(n; latent_process_priors = latent_process_priors)
n_samples = 2000
Expand All @@ -20,7 +20,7 @@ sampled_walks = prior_chn |> chn -> mapreduce(hcat, generated_quantities(model,
gen[1]
end
## From law of total variance and known mean of HalfNormal distribution
theoretical_std = [t * latent_process_priors.var_RW_dist.untruncated.σ * sqrt(2) / sqrt(π)
theoretical_std = [t * latent_process_priors.var_RW_prior.untruncated.σ * sqrt(2) / sqrt(π)
for t in 1:n] .|> sqrt

plt_ppc_rw = plot(
Expand All @@ -46,7 +46,7 @@ plot!(
)
plot!(
σ_hist,
latent_process_priors.var_RW_dist,
latent_process_priors.var_RW_prior,
lw = 2,
c = :red,
alpha = 0.5,
Expand Down
8 changes: 4 additions & 4 deletions EpiAware/test/test_latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
(var(samples_day_5) - 5) > -5 * theoretical_std_of_empiral_var
end
@testitem "Testing default_rw_priors" begin
@testset "var_RW_dist" begin
@testset "var_RW_prior" begin
priors = default_rw_priors()
var_RW = rand(priors.var_RW_dist)
var_RW = rand(priors.var_RW_prior)
@test var_RW >= 0.0
end

@testset "init_rw_value_dist" begin
@testset "init_rw_value_prior" begin
priors = default_rw_priors()
init_rw_value = rand(priors.init_rw_value_dist)
init_rw_value = rand(priors.init_rw_value_prior)
@test typeof(init_rw_value) == Float64
end
end

0 comments on commit 4ea4c3c

Please sign in to comment.