diff --git a/pipeline/scripts/create_figure1.jl b/pipeline/scripts/create_figure1.jl index 46fbb39f4..a33d15aa7 100644 --- a/pipeline/scripts/create_figure1.jl +++ b/pipeline/scripts/create_figure1.jl @@ -32,7 +32,7 @@ latent_model_dict = Dict( ## `ar` is the default latent model which we show as figure 1, others are for SI -_ = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model +figs = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model map(Iterators.product(gi_means, gi_means)) do (true_gi_choice, used_gi_choice) fig = figureone( prediction_df, truth_data_df, scenarios, targets; scenario_dict, target_dict, diff --git a/pipeline/scripts/create_figure2.jl b/pipeline/scripts/create_figure2.jl index f65cde1b1..903aaa218 100644 --- a/pipeline/scripts/create_figure2.jl +++ b/pipeline/scripts/create_figure2.jl @@ -1,25 +1,16 @@ -## Script to make figure 2 and alternate latent models for SI -using Pkg -Pkg.activate(joinpath(@__DIR__(), "..")) - using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta, - Statistics, Distributions, CSV + Statistics, Distributions, CSV, CairoMakie -## -pipelines = [ - SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), - SmoothEndemicPipeline(), RoughEndemicPipeline()] +## Define scenarios and targets +scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] +targets = ["log_I_t", "rt", "Rt"] +gi_means = [2.0, 10.0, 20.0] ## load some data and create a dataframe for the plot -truth_data_files = readdir(datadir("truth_data")) |> - strs -> filter(s -> occursin("jld2", s), strs) -analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame -truth_df = mapreduce(vcat, truth_data_files) do filename - D = JLD2.load(joinpath(datadir("truth_data"), filename)) - make_truthdata_dataframe(filename, D, pipelines) -end +truth_data_df = CSV.File(plotsdir("plotting_data/truthdata.csv")) |> DataFrame +prediction_df = 
CSV.File(plotsdir("plotting_data/predictions.csv")) |> DataFrame -# Define scenario titles and reference times for figure 2 +# Define scenario titles and reference times for figure 2 scenario_dict = Dict( "measures_outbreak" => (title = "Outbreak with measures", T = 28), "smooth_outbreak" => (title = "Outbreak no measures", T = 35), @@ -28,23 +19,66 @@ scenario_dict = Dict( ) target_dict = Dict( - "log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6), ord = 1), - "rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1), ord = 2), - "Rt" => (title = "Reproductive number", ylims = (-0.1, 3), ord = 3) + "log_I_t" => (title = "log(Incidence)",), + "rt" => (title = "Exp. growth rate",), + "Rt" => (title = "Reproductive number",) ) latent_model_dict = Dict( - "wkly_rw" => (title = "Random walk",), - "wkly_ar" => (title = "AR(1)",), - "wkly_diff_ar" => (title = "Diff. AR(1)",) + "rw" => (title = "Random walk",), + "ar" => (title = "AR(1)",), + "diff_ar" => (title = "Diff. AR(1)",) ) ## +# **Fig 2**: _Overview_: This fig aims at presenting the nowcasting (i.e. 0 horizon estimate) +# at rolling inference time points for each scenario with each inference model choice _and_ +# possible misspecification of generation interval. Time horizon choice: Chosen horizon = 0 +# to align with Fig 1 but with other horizons as SI plots. _Plotting details:_ 3 x 4 = 12 rows +# corresponding to 4 main scenarios (e.g. outbreak with measures etc.) and 3 main targets (e.g. +# exponential growth rate etc), the scenario GI is fixed to the middle mean GI (10 days; +# others are in SI) and 3 columns corresponding to _underestimating mean GI_ (left), good +# estimation of GI (middle) and overestimating mean GI (right). Actual values as scatter plot. +# The posterior inferred values at the estimation date are plotted as boxplot quantiles +# with colour determining the inference model.
+ +df = EpiAwarePipeline._fig2_pred_filter(prediction_df, "smooth_outbreak", "log_I_t", "ar", + 0; true_gi_choice = 10.0, used_gi_choice = 10.0) +truth_df = EpiAwarePipeline._fig_truth_filter( + truth_data_df, "smooth_outbreak", "log_I_t"; true_gi_choice = 10.0) +fig = Figure() +ax = Axis(fig[1, 1]) +EpiAwarePipeline._plot_predictions!( + ax, df; igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + colors = [:red, :blue, :green], iqr_alpha = 0.3) +EpiAwarePipeline._plot_truth!(ax, truth_df; color = :black) +vlines!(ax, df.Reference_Time |> unique, color = :black, linestyle = :dash) +ax.limits = ((minimum(df.Reference_Time) - 7, maximum(df.Reference_Time) + 1), nothing) +ax.xticks = vcat(minimum(df.Reference_Time) - 7, df.Reference_Time |> unique) +fig + +# figs = mapreduce(vcat, scenarios) do scenario +# mapreduce(gi_means) do true_gi_choice +# fig = figuretwo( +# truth_data_df, prediction_df, "ar", scenario_dict, target_dict; +# true_gi_choice = true_gi_choice) +# save(plotsdir("figure2_$(scenario)_trueGI_$(true_gi_choice).png"), fig) +# end +# end + +## + +## `ar` is the default latent model which we show as figure 2, others are for SI + +figs = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model + fig = figuretwo( + prediction_df, truth_data_df, scenarios, targets, 0; + scenario_dict, target_dict, latent_model_dict, + latent_model, igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + true_gi_choice = 10.0, other_gi_choices = [2.0, 10.0, 20.0], data_color = :black, + colors = [:red, :blue, :green], iqr_alpha = 0.3, horizon_diff = 7) -fig = figuretwo( - truth_df, analysis_df, "Renewal", scenario_dict, target_dict) -_ = map(analysis_df.IGP_Model |> unique) do igp - fig = figureone( - truth_df, analysis_df, latent_model, scenario_dict, target_dict, latent_model_dict) - save(plotsdir("figure2_$(igp).png"), fig) + save( + plotsdir("figure2_$(latent_model).png"), + fig) end diff --git a/pipeline/src/plotting/figureone.jl
b/pipeline/src/plotting/figureone.jl index 1a040bd15..d0536cc2b 100644 --- a/pipeline/src/plotting/figureone.jl +++ b/pipeline/src/plotting/figureone.jl @@ -1,76 +1,18 @@ """ -Plot predictions on the given axis (`ax`) based on the provided parameters. - -# Arguments -- `ax`: The axis on which to plot the predictions. -- `predictions`: DataFrame containing the prediction data. -- `scenario`: The scenario to filter the predictions. -- `target`: The target to filter the predictions. -- `reference_time`: The reference time to filter the predictions. -- `latent_model`: The latent model to filter the predictions. -- `igps`: A list of IGP models to plot. Default is `["DirectInfections", "ExpGrowthRate", "Renewal"]`. -- `true_gi_choice`: The true generation interval mean to filter the predictions. Default is `2.0`. -- `used_gi_choice`: The used generation interval mean to filter the predictions. Default is `2.0`. -- `colors`: A list of colors for each IGP model. Default is `[:red, :blue, :green]`. -- `iqr_alpha`: The alpha value for the interquartile range bands. Default is `0.3`. - -# Description -This function filters the `predictions` DataFrame based on the provided parameters and plots - the predictions on the given axis (`ax`). It plots the median prediction line and two - bands representing the interquartile range (IQR) and the 95% prediction interval for - each IGP model specified in `igps`. 
- -""" -function _plot_predictions!( - ax, predictions, scenario, target, reference_time, latent_model; - igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], - true_gi_choice = 2.0, used_gi_choice = 2.0, colors = [:red, :blue, :green], - iqr_alpha = 0.3) - pred = predictions |> - df -> @subset(df, :Latent_Model.==latent_model) |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> - df -> @subset(df, :Reference_Time.==reference_time) |> - df -> @subset(df, :Scenario.==scenario) |> - df -> @subset(df, :Target.==target) - for (c, igp) in zip(colors, igps) - x = pred[pred.IGP_Model .== igp, "target_times"] - y = pred[pred.IGP_Model .== igp, "q_5"] - upr1 = pred[pred.IGP_Model .== igp, "q_75"] - upr2 = pred[pred.IGP_Model .== igp, "q_975"] - lwr1 = pred[pred.IGP_Model .== igp, "q_25"] - lwr2 = pred[pred.IGP_Model .== igp, "q_025"] - if length(x) > 0 - lines!(ax, x, y, color = c, label = igp, linewidth = 3) - band!(ax, x, lwr1, upr1, color = (c, iqr_alpha)) - band!(ax, x, lwr2, upr2, color = (c, iqr_alpha / 2)) - end - end - return nothing -end - +Filter the `predictions` DataFrame for `scenario`, `target`, `reference_time`, + `latent_model`, `true_gi_choice`, and `used_gi_choice`. This is aimed at generating + facets for figure 1. """ -Plot the truth data on the given axis. - -# Arguments -- `ax`: The axis to plot on. -- `truth`: The DataFrame containing the truth data. -- `scenario`: The scenario to filter the truth data by. -- `target`: The target to filter the truth data by. -- `true_gi_choice`: The true generation interval choice to filter the truth data by (default is 2.0). -- `color`: The color of the scatter plot (default is :black). 
- -""" -function _plot_truth!(ax, truth, scenario, target; true_gi_choice = 2.0, color = :black) - pred = truth |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Scenario.==scenario) |> - df -> @subset(df, :Target.==target) - x = pred[!, "target_times"] - y = pred[!, "target_values"] - scatter!(ax, x, y, color = color, label = "Data") - - return nothing +function _fig1_pred_filter(predictions, scenario, target, reference_time, + latent_model; true_gi_choice = 2.0, used_gi_choice = 2.0) + df = predictions |> + df -> @subset(df, :Latent_Model.==latent_model) |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> + df -> @subset(df, :Reference_Time.==reference_time) |> + df -> @subset(df, :Scenario.==scenario) |> + df -> @subset(df, :Target.==target) + return df end """ @@ -103,13 +45,16 @@ function figureone( axs = mapreduce(hcat, enumerate(targets)) do (i, target) map(enumerate(scenarios)) do (j, scenario) ax = Axis(fig[i, j]) - _plot_predictions!( - ax, prediction_df, scenario, target, scenario_dict[scenario].T, - latent_model; true_gi_choice, used_gi_choice, colors, iqr_alpha, igps) - _plot_truth!( - ax, truth_data_df, scenario, target; true_gi_choice, color = data_color) + #Filter the data for fig1 panels + pred_df = _fig1_pred_filter( + prediction_df, scenario, target, scenario_dict[scenario].T, + latent_model; true_gi_choice, used_gi_choice) + truth_df = _fig_truth_filter(truth_data_df, scenario, target; true_gi_choice) + #Plot onto axes + _plot_predictions!(ax, pred_df; igps, colors, iqr_alpha) + _plot_truth!(ax, truth_df; color = data_color) vlines!(ax, [scenario_dict[scenario].T], color = data_color, - linewidth = 3, label = "Horizon") + linewidth = 3, label = "Reference time") if i == 1 ax.title = scenario_dict[scenario].title end diff --git a/pipeline/src/plotting/figuretwo.jl b/pipeline/src/plotting/figuretwo.jl index aeb5a911f..2d86f1612 100644 --- 
a/pipeline/src/plotting/figuretwo.jl +++ b/pipeline/src/plotting/figuretwo.jl @@ -1,91 +1,75 @@ -function _make_captions!(df, scenario_dict, target_dict) - scenario_titles = [scenario_dict[scenario].title for scenario in df.Scenario] - target_titles = [target_dict[target].title for target in df.Target] - df.Scenario_Target .= scenario_titles .* "\n" .* target_titles - return nothing +function _fig2_pred_filter(predictions, scenario, target, latent_model, horizon; + true_gi_choice, used_gi_choice, horizon_diff = 7) + df = predictions |> + df -> @subset(df, :Latent_Model.==latent_model) |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> + df -> @subset(df, + horizon-horizon_diff.<(:target_times.-:Reference_Time).<=horizon) |> + df -> @subset(df, :Scenario.==scenario) |> + df -> @subset(df, :Target.==target) + return df end -function _figure_two_truth_data( - truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices = [ - 2.0, 10.0, 20.0]) - _truth_df = mapreduce(vcat, gi_choices) do used_gi - df = deepcopy(truth_df) - df.Used_GI_Mean .= used_gi - df +function figuretwo( + prediction_df, truth_data_df, scenarios, targets, horizon; + scenario_dict, target_dict, latent_model_dict, + latent_model = "ar", igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + true_gi_choice = 10.0, other_gi_choices = [2.0, 10.0, 20.0], data_color = :black, + colors = [:red, :blue, :green], iqr_alpha = 0.3, horizon_diff = 7) + fig = Figure(; size = (1000, 800 * length(scenarios))) + axs = mapreduce(vcat, enumerate(scenarios)) do (i, scenario) + n = length(targets) + Label(fig[(n * (i - 1) + 1):(n * i), 0], + scenario_dict[scenario].title, rotation = pi / 2, fontsize = 36) + mapreduce(hcat, enumerate(targets)) do (j, target) + map(enumerate(other_gi_choices)) do (k, used_gi_choice) + row = j + (i - 1) * length(targets) + ax = Axis(fig[row, k]) + # #Filter the data for fig2 panels + pred_df = _fig2_pred_filter( + 
prediction_df, scenario, target, latent_model, horizon; + true_gi_choice, used_gi_choice, horizon_diff) + truth_df = _fig_truth_filter( + truth_data_df, scenario, target; true_gi_choice) + # #Plot onto axes + _plot_predictions!(ax, pred_df; igps, colors, iqr_alpha) + _plot_truth!(ax, truth_df; color = data_color) + vlines!(ax, pred_df.Reference_Time |> unique, color = :black, + linestyle = :dash, label = "Reference time") + # axes + if row == 1 + if k == 1 + ax.title = "Underestimating mean GI" + elseif k == 2 + ax.title = "Good estimation of GI" + elseif k == 3 + ax.title = "Overestimating mean GI" + end + end + if row == length(targets) * length(scenarios) + ax.xlabel = "Time" + end + if k == 1 + ax.ylabel = target_dict[target].title + end + ax.limits = ( + (minimum(pred_df.Reference_Time) - horizon_diff, + maximum(pred_df.Reference_Time) + 1), + nothing) + ax.xticks = vcat(minimum(pred_df.Reference_Time) - horizon_diff, + pred_df.Reference_Time |> unique) + ax + end + end end - _make_captions!(_truth_df, scenario_dict, target_dict) - - truth_plotting_data = _truth_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @transform(df, :Data="Truth data") |> data - plt_truth = truth_plotting_data * - mapping(:target_times => "T", :target_values => "Process values", - row = :Scenario_Target, - col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI", - 10.0 => "Good GI", 20.0 => "Overestimate GI"]), - color = :Data => AlgebraOfGraphics.scale(:color2)) * - visual(AlgebraOfGraphics.Scatter) - return plt_truth -end - -function _figure_two_scenario( - analysis_df, igp, scenario_dict, target_dict; true_gi_choice, - lower_sym = :q_025, upper_sym = :q_975) - min_ref_time = minimum(analysis_df.Reference_Time) - early_df = analysis_df |> - df -> @subset(df, :Reference_Time.==min_ref_time) |> - df -> @subset(df, :IGP_Model.==igp) |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :target_times.<=min_ref_time - 7) - - seqn_df = 
analysis_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :IGP_Model.==igp) |> - df -> @subset(df, - :Reference_Time .- :target_times.∈fill(0:6, size(df, 1))) - - full_df = vcat(early_df, seqn_df) - _make_captions!(full_df, scenario_dict, target_dict) - - model_plotting_data = full_df |> data - - plt_model = model_plotting_data * - mapping(:target_times => "T", :q_5 => "Process values", - row = :Scenario_Target, - col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI", - 10.0 => "Good GI", 20.0 => "Overestimate GI"]), - color = :Latent_Model => "Latent models") * - mapping(lower = lower_sym, upper = upper_sym) * - visual(LinesFill) - - return plt_model -end - -function figuretwo(truth_df, analysis_df, igp, scenario_dict, - target_dict; fig_kws = (; size = (1000, 2800)), - true_gi_choice = 10.0, gi_choices = [2.0, 10.0, 20.0]) - - # Perform checks on the dataframes - _dataframe_checks(truth_df, analysis_df, scenario_dict) - - f_td = _figure_two_truth_data( - truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices) - f_mdl = _figure_two_scenario( - analysis_df, igp, scenario_dict, target_dict; true_gi_choice) - - fg = draw(f_mdl + f_td; facet = (; linkyaxes = :none), - legend = (; orientation = :horizontal, position = :bottom), - figure = fig_kws, - axis = (; xlabel = "T", ylabel = "Process values")) - for g in fg.grid[1:3:end, :] - g.axis.limits = (nothing, target_dict["rt"].ylims) - end - for g in fg.grid[2:3:end, :] - g.axis.limits = (nothing, target_dict["Rt"].ylims) - end - for g in fg.grid[3:3:end, :] - g.axis.limits = (nothing, target_dict["log_I_t"].ylims) - end - - return fg + leg = Legend(fig[length(targets) * length(scenarios) + 1, 1:2], + last(axs), "Infection generating process"; + orientation = :horizontal, tellwidth = false, framevisible = false) + lab = Label(fig[length(targets) * length(scenarios) + 1, length(other_gi_choices)], + "Latent model for \n infection generating\n process: 
$(latent_model_dict[latent_model].title) \n True mean GI: $(true_gi_choice) days \n Horizon: $(horizon) days"; + tellwidth = false, + fontsize = 18) + resize_to_layout!(fig) + fig end diff --git a/pipeline/src/plotting/panel_plots.jl b/pipeline/src/plotting/panel_plots.jl new file mode 100644 index 000000000..aaaa678fc --- /dev/null +++ b/pipeline/src/plotting/panel_plots.jl @@ -0,0 +1,73 @@ +""" +Plot predictions on the given axis (`ax`) from an already-filtered predictions DataFrame. + +# Arguments +- `ax`: The axis on which to plot the predictions. +- `pred`: DataFrame containing the prediction data for a single panel. No filtering by + scenario, target, latent model or generation interval is done here; filter beforehand, + e.g. with `_fig1_pred_filter` or `_fig2_pred_filter`. The columns `IGP_Model`, + `target_times`, `q_025`, `q_25`, `q_5`, `q_75` and `q_975` are read. + +# Keyword arguments +- `igps`: A list of IGP models to plot. Default is `["DirectInfections", "ExpGrowthRate", "Renewal"]`. +- `colors`: A list of colors for each IGP model. Default is `[:red, :blue, :green]`. +- `iqr_alpha`: The alpha value for the interquartile range bands. Default is `0.3`. + +# Description +For each IGP model specified in `igps` (paired in order with `colors`), this function selects + the rows of `pred` for that model and plots them on the given axis (`ax`). It plots the + median prediction line and two bands representing the interquartile range (IQR) and the + 95% prediction interval; the latter is drawn with alpha `iqr_alpha / 2`. IGP models with + no rows in `pred` are skipped.
+ +""" +function _plot_predictions!( + ax, pred; igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + colors = [:red, :blue, :green], iqr_alpha = 0.3) + for (c, igp) in zip(colors, igps) + x = pred[pred.IGP_Model .== igp, "target_times"] + y = pred[pred.IGP_Model .== igp, "q_5"] + upr1 = pred[pred.IGP_Model .== igp, "q_75"] + upr2 = pred[pred.IGP_Model .== igp, "q_975"] + lwr1 = pred[pred.IGP_Model .== igp, "q_25"] + lwr2 = pred[pred.IGP_Model .== igp, "q_025"] + if length(x) > 0 + lines!(ax, x, y, color = c, label = igp, linewidth = 3) + band!(ax, x, lwr1, upr1, color = (c, iqr_alpha)) + band!(ax, x, lwr2, upr2, color = (c, iqr_alpha / 2)) + end + end + return nothing +end + +""" +Plot the truth data on the given axis. + +# Arguments +- `ax`: The axis to plot on. +- `truth`: The DataFrame containing the truth data, already filtered to a single panel (e.g. + with `_fig_truth_filter`); its `target_times` and `target_values` columns are plotted. + +# Keyword arguments +- `color`: The color of the scatter plot (default is :black). + +""" +function _plot_truth!(ax, truth; color = :black) + x = truth[!, "target_times"] + y = truth[!, "target_values"] + scatter!(ax, x, y, color = color, label = "Data") + + return nothing +end + +""" +Filter the `truth` DataFrame for `scenario`, `target` and `true_gi_choice`. This is aimed at + generating the truth data for the facets of figures 1 and 2.
+""" +function _fig_truth_filter(truth, scenario, target; true_gi_choice) + df = truth |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Scenario.==scenario) |> + df -> @subset(df, :Target.==target) + return df +end diff --git a/pipeline/src/plotting/plotting.jl b/pipeline/src/plotting/plotting.jl index 5e594171f..ce6bdbcf0 100644 --- a/pipeline/src/plotting/plotting.jl +++ b/pipeline/src/plotting/plotting.jl @@ -1,5 +1,6 @@ include("basicplots.jl") include("df_checking.jl") +include("panel_plots.jl") include("figureone.jl") include("figuretwo.jl") include("prior_predictive_plot.jl")