Skip to content

Commit

Permalink
Merge pull request #61 from CDCgov/45-improve-the-prior-interface-for…
Browse files Browse the repository at this point in the history
…-make_epi_model_inference

Improve the prior interface for make epi model inference
  • Loading branch information
seabbs authored Feb 21, 2024
2 parents d672c2c + 792b4bd commit 42cfdc6
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 62 deletions.
33 changes: 25 additions & 8 deletions EpiAware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
- Solid lines indicate implemented features/analysis.
- Dashed lines indicate planned features/analysis.

## Proposed `EpiAware` model diagram
```mermaid
flowchart TD
flowchart LR
A["Underlying dists.
and specify length of sims
Expand All @@ -29,24 +30,40 @@ C["Observational Data
Obs. cases y_t"]
D["Latent processes
---------------------
Random Walk"]
E[Turing model constructor]
F["Latent Process priors"]
random_walk"]
E["Turing model constructor
---------------------
make_epi_inference_model"]
F["Latent Process priors
---------------------
default_rw_priors"]
G[Posterior draws]
H[Posterior checking]
I[Post-processing]
DataW[Data wrangling and QC]
J["Observation Model
J["Observation models
---------------------
delay_observations"]
K["Observation model priors
---------------------
default_delay_obs_priors"]
ObservationModel["ObservationModel
---------------------
delay_observations_model"]
LatentProcess["LatentProcess
---------------------
random_walk_process"]
A --> EpiModel
B --> EpiModel
EpiModel -->E
C-->E
D-->|random_walk| E
J-->E
F-->|default_rw_priors|E
D-->LatentProcess
F-->LatentProcess
J-->ObservationModel
K-->ObservationModel
LatentProcess-->E
ObservationModel-->E
E-->|sample...NUTS...| G
G-.->H
H-.->I
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, random_walk, delay_observations
export make_epi_inference_model, delay_observations_model, random_walk_process

include("epimodel.jl")
include("utilities.jl")
include("models.jl")
include("latent-processes.jl")
include("observation-processes.jl")
include("models.jl")

end
31 changes: 31 additions & 0 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,34 @@ end
end
return rw, init, (; σ_RW,)
end

"""
struct LatentProcess{F<:Function}
A struct representing a latent process with its priors.
# Fields
- `latent_process`: The latent process function for a `Turing` model.
- `latent_process_priors`: NamedTuple containing the priors for the latent process.
"""
struct LatentProcess{F <: Function}
latent_process::F
latent_process_priors::NamedTuple
end

"""
random_walk_process(; latent_process_priors = default_rw_priors())
Create a `LatentProcess` struct reflecting a random walk process with optional priors.
# Arguments
- `latent_process_priors`: Optional priors for the random walk process.
# Returns
- `LatentProcess`: A random walk process.
"""
function random_walk_process(; latent_process_priors = default_rw_priors())
LatentProcess(random_walk, latent_process_priors)
end
15 changes: 8 additions & 7 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
@model function make_epi_inference_model(
y_t,
epimodel::AbstractEpiModel,
latent_process,
observation_process;
process_priors,
latent_process_obj::LatentProcess,
observation_process_obj::ObservationModel;
pos_shift = 1e-6
)
#Latent process
time_steps = epimodel.data.time_horizon
@submodel latent_process, init, latent_process_aux = latent_process(
time_steps; latent_process_priors = process_priors)
@submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process(
time_steps;
latent_process_priors = latent_process_obj.latent_process_priors
)

#Transform into infections
I_t = epimodel(latent_process, init)

#Predictive distribution of ascerted cases
@submodel generated_y_t, generated_y_t_aux = observation_process(
@submodel generated_y_t, generated_y_t_aux = observation_process_obj.observation_model(
y_t,
I_t,
epimodel::AbstractEpiModel;
observation_process_priors = process_priors,
observation_process_priors = observation_process_obj.observation_model_priors,
pos_shift = pos_shift
)

Expand Down
31 changes: 31 additions & 0 deletions EpiAware/src/observation-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,34 @@ end

return y_t, (; neg_bin_cluster_factor,)
end

"""
struct ObservationModel{F<:Function}
A struct representing an observation model with its priors.
# Fields
- `observation_model`: The observation model function for a `Turing` model.
- `observation_model_priors`: NamedTuple containing the priors for the observation model.
"""
struct ObservationModel{F <: Function}
observation_model::F
observation_model_priors::NamedTuple
end

"""
delay_observations_model(; latent_process_priors = default_rw_priors())
Create an `ObservationModel` struct reflecting a delayed observation process with optional priors.
# Arguments
- `latent_process_priors`: Optional priors for the delayed observation process.
# Returns
- `ObservationModel`: An observation model with delayed observations.
"""
function delay_observations_model(; observation_model_priors = default_delay_obs_priors())
ObservationModel(delay_observations, observation_model_priors)
end
19 changes: 4 additions & 15 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,16 @@ In this case we use the `DirectInfections` model.
=#

toy_log_infs = DirectInfections(model_data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

#=
## Generate a `Turing` `Model`
We don't have observed data, so we use `missing` value for `y_t`.
=#

log_infs_model = make_epi_inference_model(
missing,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6
)
missing, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

#=
## Sample from the model
Expand Down Expand Up @@ -147,14 +143,7 @@ We treat the generated data as observed data and attempt to infer underlying inf

truth_data = random_epidemic.y_t

model = make_epi_inference_model(
truth_data,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6
)
model = make_epi_inference_model(truth_data, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

@time chn = sample(
model,
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@testitem "Testing random_walk against theoretical properties" begin
using DynamicPPL, Turing
n = 5
model = random_walk(n)
model = EpiAware.random_walk(n)
fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 = sample(fixed_model, Prior(), n_samples) |>
Expand Down
37 changes: 9 additions & 28 deletions EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,13 @@
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = DirectInfections(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand All @@ -38,20 +31,14 @@ end
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = ExpGrowthRate(data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand All @@ -76,16 +63,10 @@ end
pos_shift = 1e-6

epimodel = Renewal(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
observation_process_priors = default_delay_obs_priors()

# Call the function
mdl = delay_observations(
mdl = EpiAware.delay_observations(
missing,
I_t,
epimodel;
Expand Down

0 comments on commit 42cfdc6

Please sign in to comment.