From af6dfb6ccc6505d6573de3f6fc61b4c7434b9ec1 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:48:01 +0000 Subject: [PATCH] Refactor prior construction and make clearer the logic (#566) --- .../src/constructors/remake_latent_model.jl | 187 ++++++++++-------- .../test/constructors/remake_latent_model.jl | 20 +- 2 files changed, 120 insertions(+), 87 deletions(-) diff --git a/pipeline/src/constructors/remake_latent_model.jl b/pipeline/src/constructors/remake_latent_model.jl index 964f5d448..06599d712 100644 --- a/pipeline/src/constructors/remake_latent_model.jl +++ b/pipeline/src/constructors/remake_latent_model.jl @@ -3,92 +3,123 @@ Constructs and returns a latent model based on the provided `inference_config` a The purpose of this function is to make adjustments to the latent model based on the full `inference_config` provided. -The `tscale` argument is used to scale the standard deviation of the latent model based on the -idea that some processes have a variance that is (approximately) proportional to a time period (due to non-stationarity) -and some processes have a variance that is constant in time (at stationarity). The default -value is `sqrt(21.0)`, which corresponds to matching the variance of stationary processes to -the eventual variance of non-stationary process after 21 days. - The `pipeline` argument is used for dispatch purposes. -# Returns -- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`. +The prior decisions are based on the target standard deviation and autocorrelation of the latent process, +which are determined by the infection generating process (igp) and whether the latent process is stationary or non-stationary +via the `_make_target_std_and_autocorr` function. -# Details -- The function first constructs a dictionary of priors using `make_model_priors(pipeline)`. -- It then retrieves the `igp` (inference generation process) and `latent_model_name` from `inference_config`. -- Depending on the `latent_model_name` and `igp`, it constructs and returns the appropriate latent model: - - `"diff_ar"`: Constructs a `DiffLatentModel` with an `AR` model. - - `"ar"`: Constructs an `AR` model. - - `"rw"`: Constructs a `RandomWalk` model. -- The priors for the models are set based on the `prior_dict` and the `tscale` parameter. +# Returns +- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`. """ -function remake_latent_model(inference_config::Dict, - pipeline::AbstractRtwithoutRenewalPipeline; tscale = sqrt(21.0)) +function remake_latent_model( + inference_config::Dict, pipeline::AbstractRtwithoutRenewalPipeline) #Baseline choices prior_dict = make_model_priors(pipeline) igp = inference_config["igp"] - latent_model_name = inference_config["latent_namemodels"].first - - if latent_model_name == "diff_ar" - if igp == Renewal - ar = AR(damp_priors = [prior_dict["damp_param_prior"]], - std_prior = HalfNormal(0.05 / tscale), - init_priors = [prior_dict["transformed_process_init_prior"]]) - diff_ar = DiffLatentModel(; - model = ar, init_priors = [prior_dict["transformed_process_init_prior"]]) - return diff_ar - elseif igp == ExpGrowthRate - ar = AR(damp_priors = [prior_dict["damp_param_prior"]], - std_prior = HalfNormal(0.005 / tscale), - init_priors = [prior_dict["transformed_process_init_prior"]]) - diff_ar = DiffLatentModel(; - model = ar, init_priors = [prior_dict["transformed_process_init_prior"]]) - return diff_ar - elseif igp == DirectInfections - ar = AR(damp_priors = [Beta(9, 1)], - std_prior = HalfNormal(0.05 / tscale), - init_priors = [prior_dict["transformed_process_init_prior"]]) - diff_ar = DiffLatentModel(; - model = ar, init_priors = [prior_dict["transformed_process_init_prior"]]) - return diff_ar - end - elseif latent_model_name == "ar" - if igp == Renewal - ar = AR(damp_priors = [Beta(2, 8)], - std_prior = HalfNormal(0.25), - init_priors = [prior_dict["transformed_process_init_prior"]]) - return ar - elseif igp == ExpGrowthRate - ar = AR(damp_priors = [prior_dict["damp_param_prior"]], - std_prior = HalfNormal(0.025), - init_priors = [prior_dict["transformed_process_init_prior"]]) - return ar - elseif igp == DirectInfections - ar = AR(damp_priors = [Beta(9, 1)], - std_prior = HalfNormal(0.25), - init_priors = [prior_dict["transformed_process_init_prior"]]) - return ar - end - elseif latent_model_name == "rw" - if igp == Renewal - rw = RandomWalk( - std_prior = HalfNormal(0.05 / tscale), - init_prior = prior_dict["transformed_process_init_prior"]) - return rw - elseif igp == ExpGrowthRate - rw = RandomWalk( - std_prior = HalfNormal(0.005 / tscale), - init_prior = prior_dict["transformed_process_init_prior"]) - return rw - elseif igp == DirectInfections - rw = RandomWalk( - std_prior = HalfNormal(0.1 / tscale), - init_prior = prior_dict["transformed_process_init_prior"]) - return rw - end - end + default_latent_model = inference_config["latent_namemodels"].second + target_std, target_autocorr = default_latent_model isa AR ? + _make_target_std_and_autocorr(igp; stationary = true) : + _make_target_std_and_autocorr(igp; stationary = false) + + return _implement_latent_process( + target_std, target_autocorr, default_latent_model, pipeline) +end + +""" +This function sets the target standard deviation for an infection generating process (igp) +based on whether the latent process representation of its dynamics are stationary or non-stationary. + +## Stationary Processes + +- For Renewal process `log(R_t)` in the long run a fluctuation of 0.75 (e.g. ~ 75% of the mean) is not unexpected. +- For Exponential Growth Rate process `r_t` in the long run a fluctuation of 0.2 is not unexpected e.g. going from +`rt = 0.1` (7 day doubling time) to `rt = -0.1` (7 day halving time) is a 0.2 time-to-time fluctuation. +- For Direct Infections process `log(I_t)` in the long run a fluctuation of 2.0 (i.e a couple of orders of magnitude) is not unexpected. + +For stationary latent processes Direct Infections and rt processes the autocorrelation is expected to be high at 0.9, +because persistence in residual away from mean is expected. Otherwise, the autocorrelation is expected to be 0.1. + +## Non-Stationary Processes + +For Renewal process `log(R_t)` in a single time step a fluctuation of 0.025 (e.g. ~ 2.5% of the mean) is not unexpected. +For Exponential Growth Rate process `r_t` in a single time step a fluctuation of 0.005 is not unexpected. +For Direct Infections process `log(I_t)` in a single time step a fluctuation of 0.025 is not unexpected. + +The autocorrelation is expected to be 0.1. +""" +function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool) + return stationary ? (0.75, 0.1) : (0.025, 0.1) +end + +function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool) + return stationary ? (0.2, 0.9) : (0.005, 0.1) +end + +function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool) + return stationary ? (2.0, 0.9) : (0.25, 0.1) +end + +function _make_new_prior_dict(target_std, target_autocorr, + pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size) + #Get default priors + prior_dict = make_model_priors(pipeline) + #Adjust priors based on target autocorrelation and standard deviation + damp_prior = Beta(target_autocorr * beta_eff_sample_size, + (1 - target_autocorr) * beta_eff_sample_size) + corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2)) + noise_prior = HalfNormal(target_std) + init_prior = prior_dict["transformed_process_init_prior"] + return Dict( + "transformed_process_init_prior" => init_prior, + "corr_corrected_noise_prior" => corr_corrected_noise_prior, + "noise_prior" => noise_prior, + "damp_param_prior" => damp_prior + ) +end + +""" +Constructs and returns a latent model based on an approximation to the specified target standard deviation and autocorrelation. + +NB: The stationary variance of an AR(1) process is given by `σ² = σ²_ε / (1 - ρ²)` where `σ²_ε` is the variance of the noise and `ρ` is the autocorrelation. +The approximation here are based on `E[1/(1 - ρ²)`] ≈ 1 / (1 - E[ρ²])` which only holds for fairly tight distributions of `ρ`. +However, for priors this should get the expected order of magnitude. + +# Models +- `"diff_ar"`: Constructs a `DiffLatentModel` with an autoregressive (AR) process. +- `"ar"`: Constructs an autoregressive (AR) process. +- `"rw"`: Constructs a random walk (RW) process. + +""" +function _implement_latent_process( + target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10) + prior_dict = make_model_priors(pipeline) + new_priors = _make_new_prior_dict( + target_std, target_autocorr, pipeline; beta_eff_sample_size) + + return _make_latent(default_latent_model, new_priors) +end + +function _make_latent(::AR, new_priors) + damp_prior = new_priors["damp_param_prior"] + corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"] + init_prior = new_priors["transformed_process_init_prior"] + return AR(damp_priors = [damp_prior], + std_prior = corr_corrected_noise_std, + init_priors = [init_prior]) +end + +function _make_latent(::DiffLatentModel, new_priors) + init_prior = new_priors["transformed_process_init_prior"] + ar = _make_latent(AR(), new_priors) + return DiffLatentModel(; model = ar, init_priors = [init_prior]) +end + +function _make_latent(::RandomWalk, new_priors) + noise_std = new_priors["noise_prior"] + init_prior = new_priors["transformed_process_init_prior"] + return RandomWalk(std_prior = noise_std, init_prior = init_prior) end """ diff --git a/pipeline/test/constructors/remake_latent_model.jl b/pipeline/test/constructors/remake_latent_model.jl index 305eb9543..9f6d9f5f3 100644 --- a/pipeline/test/constructors/remake_latent_model.jl +++ b/pipeline/test/constructors/remake_latent_model.jl @@ -7,49 +7,51 @@ ) end pipeline = MockPipeline() - + ar = AR() + diff_ar = DiffLatentModel(model = ar) + rw = RandomWalk() @testset "diff_ar model" begin inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => ("diff_ar" => "diff_ar")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("diff_ar", diff_ar)) model = remake_latent_model(inference_config, pipeline) @test model isa DiffLatentModel @test model.model isa AR inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => ("diff_ar" => "diff_ar")) + "igp" => DirectInfections, "latent_namemodels" => Pair("diff_ar", diff_ar)) model = remake_latent_model(inference_config, pipeline) @test model isa DiffLatentModel @test model.model isa AR end @testset "ar model" begin - inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", "ar")) + inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", "ar")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => Pair("ar", "ar")) + "igp" => DirectInfections, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR end @testset "rw model" begin - inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", "rw")) + inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", "rw")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => Pair("rw", "rw")) + "igp" => DirectInfections, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk end