From 8035d6032fcdc0eef9bbaa55cbb7802b38bab8e2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 23 Feb 2024 13:30:41 +0000 Subject: [PATCH 1/2] Changed `Turing` models intended to be submodels into constructs taking `kwargs` variable splits rather than `NamedTuple`s; updated tests --- EpiAware/src/latent-processes.jl | 6 +++--- EpiAware/src/models.jl | 6 +++--- EpiAware/src/observation-processes.jl | 7 +++---- EpiAware/test/test_latent-processes.jl | 3 ++- EpiAware/test/test_observation-processes.jl | 5 +++-- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index b85a775eb..27b68ed0f 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -5,10 +5,10 @@ function default_rw_priors() ) end -@model function random_walk(n; latent_process_priors = default_rw_priors()) +@model function random_walk(n; kwargs...) ϵ_t ~ MvNormal(ones(n)) - σ²_RW ~ latent_process_priors.var_RW_dist - init ~ latent_process_priors.init_rw_value_dist + σ²_RW ~ kwargs[:var_RW_dist] + init ~ kwargs[:init_rw_value_dist] σ_RW = sqrt(σ²_RW) rw = Vector{eltype(ϵ_t)}(undef, n) diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 8d896de85..8b3e73064 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -9,7 +9,7 @@ time_steps = epimodel.data.time_horizon @submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process( time_steps; - latent_process_priors = latent_process_obj.latent_process_priors + latent_process_obj.latent_process_priors... ) #Transform into infections @@ -20,8 +20,8 @@ y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = observation_process_obj.observation_model_priors, - pos_shift = pos_shift + pos_shift = pos_shift, + observation_process_obj.observation_model_priors... ) #Generate quantities diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index 3389a5cfd..34171c55b 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -6,14 +6,13 @@ end y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = default_delay_obs_priors(), - pos_shift = 1e-6 + kwargs... ) #Parameters - neg_bin_cluster_factor ~ observation_process_priors.neg_bin_cluster_factor_prior + neg_bin_cluster_factor ~ kwargs[:neg_bin_cluster_factor_prior] #Predictive distribution - case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ pos_shift .|> + case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ kwargs[:pos_shift] .|> μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) #Likelihood diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index 5541e515a..867b90640 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -2,7 +2,8 @@ @testitem "Testing random_walk against theoretical properties" begin using DynamicPPL, Turing n = 5 - model = EpiAware.random_walk(n) + priors = EpiAware.default_rw_priors() + model = EpiAware.random_walk(n; priors...) fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process n_samples = 1000 samples_day_5 = sample(fixed_model, Prior(), n_samples) |> diff --git a/EpiAware/test/test_observation-processes.jl b/EpiAware/test/test_observation-processes.jl index 1386d3bb5..ae239ab70 100644 --- a/EpiAware/test/test_observation-processes.jl +++ b/EpiAware/test/test_observation-processes.jl @@ -8,14 +8,15 @@ data = EpiData([0.2, 0.3, 0.5], [1.0], 0.8, 3, exp) epimodel = DirectInfections(data) # Set up priors - observation_process_priors = default_delay_obs_priors() + priors = default_delay_obs_priors() # Call the function mdl = EpiAware.delay_observations( missing, I_t, epimodel; - observation_process_priors = observation_process_priors + pos_shift = 1e-6, + priors... ) fix_mdl = fix(mdl, neg_bin_cluster_factor = 0.00001) # Effectively Poisson sampling From ee63f3dbc39b2f2cf20eff48087a18506d83bc4f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 23 Feb 2024 14:02:32 +0000 Subject: [PATCH 2/2] Remove unnecessary `kwargs...` --- EpiAware/src/latent-processes.jl | 6 +++--- EpiAware/src/observation-processes.jl | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index 27b68ed0f..c6ecd9ebd 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -5,10 +5,10 @@ function default_rw_priors() ) end -@model function random_walk(n; kwargs...) +@model function random_walk(n; var_RW_dist, init_rw_value_dist) ϵ_t ~ MvNormal(ones(n)) - σ²_RW ~ kwargs[:var_RW_dist] - init ~ kwargs[:init_rw_value_dist] + σ²_RW ~ var_RW_dist + init ~ init_rw_value_dist σ_RW = sqrt(σ²_RW) rw = Vector{eltype(ϵ_t)}(undef, n) diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index 34171c55b..cc5528d13 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -6,13 +6,14 @@ end y_t, I_t, epimodel::AbstractEpiModel; - kwargs... + neg_bin_cluster_factor_prior, + pos_shift ) #Parameters - neg_bin_cluster_factor ~ kwargs[:neg_bin_cluster_factor_prior] + neg_bin_cluster_factor ~ neg_bin_cluster_factor_prior #Predictive distribution - case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ kwargs[:pos_shift] .|> + case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ pos_shift .|> μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) #Likelihood