Skip to content

Commit

Permalink
Refactor prior construction and make clearer the logic (#566)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored Dec 23, 2024
1 parent 1b3d397 commit af6dfb6
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 87 deletions.
187 changes: 109 additions & 78 deletions pipeline/src/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,92 +3,123 @@ Constructs and returns a latent model based on the provided `inference_config` a
The purpose of this function is to make adjustments to the latent model based on the
full `inference_config` provided.
The `tscale` argument is used to scale the standard deviation of the latent model based on the
idea that some processes have a variance that is (approximately) proportional to a time period (due to non-stationarity)
and some processes have a variance that is constant in time (at stationarity). The default
value is `sqrt(21.0)`, which corresponds to matching the variance of stationary processes to
the eventual variance of non-stationary process after 21 days.
The `pipeline` argument is used for dispatch purposes.
# Returns
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
The prior decisions are based on the target standard deviation and autocorrelation of the latent process,
which are determined by the infection generating process (igp) and whether the latent process is stationary or non-stationary
via the `_make_target_std_and_autocorr` function.
# Details
- The function first constructs a dictionary of priors using `make_model_priors(pipeline)`.
- It then retrieves the `igp` (inference generation process) and `latent_model_name` from `inference_config`.
- Depending on the `latent_model_name` and `igp`, it constructs and returns the appropriate latent model:
- `"diff_ar"`: Constructs a `DiffLatentModel` with an `AR` model.
- `"ar"`: Constructs an `AR` model.
- `"rw"`: Constructs a `RandomWalk` model.
- The priors for the models are set based on the `prior_dict` and the `tscale` parameter.
# Returns
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
"""
function remake_latent_model(inference_config::Dict,
pipeline::AbstractRtwithoutRenewalPipeline; tscale = sqrt(21.0))
function remake_latent_model(
inference_config::Dict, pipeline::AbstractRtwithoutRenewalPipeline)
#Baseline choices
prior_dict = make_model_priors(pipeline)
igp = inference_config["igp"]
latent_model_name = inference_config["latent_namemodels"].first

if latent_model_name == "diff_ar"
if igp == Renewal
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = HalfNormal(0.05 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
elseif igp == ExpGrowthRate
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = HalfNormal(0.005 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
elseif igp == DirectInfections
ar = AR(damp_priors = [Beta(9, 1)],
std_prior = HalfNormal(0.05 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
end
elseif latent_model_name == "ar"
if igp == Renewal
ar = AR(damp_priors = [Beta(2, 8)],
std_prior = HalfNormal(0.25),
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
elseif igp == ExpGrowthRate
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = HalfNormal(0.025),
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
elseif igp == DirectInfections
ar = AR(damp_priors = [Beta(9, 1)],
std_prior = HalfNormal(0.25),
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
end
elseif latent_model_name == "rw"
if igp == Renewal
rw = RandomWalk(
std_prior = HalfNormal(0.05 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
elseif igp == ExpGrowthRate
rw = RandomWalk(
std_prior = HalfNormal(0.005 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
elseif igp == DirectInfections
rw = RandomWalk(
std_prior = HalfNormal(0.1 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
end
end
default_latent_model = inference_config["latent_namemodels"].second
target_std, target_autocorr = default_latent_model isa AR ?
_make_target_std_and_autocorr(igp; stationary = true) :
_make_target_std_and_autocorr(igp; stationary = false)

return _implement_latent_process(
target_std, target_autocorr, default_latent_model, pipeline)
end

"""
This function sets the target standard deviation for an infection generating process (igp)
based on whether the latent process representation of its dynamics are stationary or non-stationary.
## Stationary Processes
- For Renewal process `log(R_t)` in the long run a fluctuation of 0.75 (e.g. ~ 75% of the mean) is not unexpected.
- For Exponential Growth Rate process `r_t` in the long run a fluctuation of 0.2 is not unexpected e.g. going from
`rt = 0.1` (7 day doubling time) to `rt = -0.1` (7 day halving time) is a 0.2 time-to-time fluctuation.
- For Direct Infections process `log(I_t)` in the long run a fluctuation of 2.0 (i.e a couple of orders of magnitude) is not unexpected.
For stationary latent processes Direct Infections and rt processes the autocorrelation is expected to be high at 0.9,
because persistence in residual away from mean is expected. Otherwise, the autocorrelation is expected to be 0.1.
## Non-Stationary Processes
For Renewal process `log(R_t)` in a single time step a fluctuation of 0.025 (e.g. ~ 2.5% of the mean) is not unexpected.
For Exponential Growth Rate process `r_t` in a single time step a fluctuation of 0.005 is not unexpected.
For Direct Infections process `log(I_t)` in a single time step a fluctuation of 0.025 is not unexpected.
The autocorrelation is expected to be 0.1.
"""
function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool)
return stationary ? (0.75, 0.1) : (0.025, 0.1)
end

function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool)
return stationary ? (0.2, 0.9) : (0.005, 0.1)
end

function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool)
return stationary ? (2.0, 0.9) : (0.25, 0.1)
end

function _make_new_prior_dict(target_std, target_autocorr,
pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size)
#Get default priors
prior_dict = make_model_priors(pipeline)
#Adjust priors based on target autocorrelation and standard deviation
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
(1 - target_autocorr) * beta_eff_sample_size)
corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
noise_prior = HalfNormal(target_std)
init_prior = prior_dict["transformed_process_init_prior"]
return Dict(
"transformed_process_init_prior" => init_prior,
"corr_corrected_noise_prior" => corr_corrected_noise_prior,
"noise_prior" => noise_prior,
"damp_param_prior" => damp_prior
)
end

"""
Constructs and returns a latent model based on an approximation to the specified target standard deviation and autocorrelation.
NB: The stationary variance of an AR(1) process is given by `σ² = σ²_ε / (1 - ρ²)` where `σ²_ε` is the variance of the noise and `ρ` is the autocorrelation.
The approximation here are based on `E[1/(1 - ρ²)`] ≈ 1 / (1 - E[ρ²])` which only holds for fairly tight distributions of `ρ`.
However, for priors this should get the expected order of magnitude.
# Models
- `"diff_ar"`: Constructs a `DiffLatentModel` with an autoregressive (AR) process.
- `"ar"`: Constructs an autoregressive (AR) process.
- `"rw"`: Constructs a random walk (RW) process.
"""
function _implement_latent_process(
target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10)
prior_dict = make_model_priors(pipeline)
new_priors = _make_new_prior_dict(
target_std, target_autocorr, pipeline; beta_eff_sample_size)

return _make_latent(default_latent_model, new_priors)
end

function _make_latent(::AR, new_priors)
damp_prior = new_priors["damp_param_prior"]
corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return AR(damp_priors = [damp_prior],
std_prior = corr_corrected_noise_std,
init_priors = [init_prior])
end

function _make_latent(::DiffLatentModel, new_priors)
init_prior = new_priors["transformed_process_init_prior"]
ar = _make_latent(AR(), new_priors)
return DiffLatentModel(; model = ar, init_priors = [init_prior])
end

function _make_latent(::RandomWalk, new_priors)
noise_std = new_priors["noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return RandomWalk(std_prior = noise_std, init_prior = init_prior)
end

"""
Expand Down
20 changes: 11 additions & 9 deletions pipeline/test/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,51 @@
)
end
pipeline = MockPipeline()

ar = AR()
diff_ar = DiffLatentModel(model = ar)
rw = RandomWalk()
@testset "diff_ar model" begin
inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => ("diff_ar" => "diff_ar"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("diff_ar", diff_ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa DiffLatentModel
@test model.model isa AR

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => ("diff_ar" => "diff_ar"))
"igp" => DirectInfections, "latent_namemodels" => Pair("diff_ar", diff_ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa DiffLatentModel
@test model.model isa AR
end

@testset "ar model" begin
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", "ar"))
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR

inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", "ar"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", "ar"))
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR
end

@testset "rw model" begin
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", "rw"))
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk

inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", "rw"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", "rw"))
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk
end
Expand Down

0 comments on commit af6dfb6

Please sign in to comment.