Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Figure 2 script #558

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pipeline/scripts/create_figure1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ latent_model_dict = Dict(

## `ar` is the default latent model which we show as figure 1, others are for SI

_ = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model
figs = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model
map(Iterators.product(gi_means, gi_means)) do (true_gi_choice, used_gi_choice)
fig = figureone(
prediction_df, truth_data_df, scenarios, targets; scenario_dict, target_dict,
Expand Down
92 changes: 63 additions & 29 deletions pipeline/scripts/create_figure2.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
## Script to make figure 2 and alternate latent models for SI
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))

using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta,
Statistics, Distributions, CSV
Statistics, Distributions, CSV, CairoMakie

##
pipelines = [
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(),
SmoothEndemicPipeline(), RoughEndemicPipeline()]
## Define scenarios and targets
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]
targets = ["log_I_t", "rt", "Rt"]
gi_means = [2.0, 10.0, 20.0]

## load some data and create a dataframe for the plot
truth_data_files = readdir(datadir("truth_data")) |>
strs -> filter(s -> occursin("jld2", s), strs)
analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame
truth_df = mapreduce(vcat, truth_data_files) do filename
D = JLD2.load(joinpath(datadir("truth_data"), filename))
make_truthdata_dataframe(filename, D, pipelines)
end
truth_data_df = CSV.File(plotsdir("plotting_data/truthdata.csv")) |> DataFrame
prediction_df = CSV.File(plotsdir("plotting_data/predictions.csv")) |> DataFrame

# Define scenario titles and reference times for figure 2
# Define scenario titles and reference times for figure 1
scenario_dict = Dict(
"measures_outbreak" => (title = "Outbreak with measures", T = 28),
"smooth_outbreak" => (title = "Outbreak no measures", T = 35),
Expand All @@ -28,23 +19,66 @@ scenario_dict = Dict(
)

target_dict = Dict(
"log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6), ord = 1),
"rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1), ord = 2),
"Rt" => (title = "Reproductive number", ylims = (-0.1, 3), ord = 3)
"log_I_t" => (title = "log(Incidence)",),
"rt" => (title = "Exp. growth rate",),
"Rt" => (title = "Reproductive number",)
)

latent_model_dict = Dict(
"wkly_rw" => (title = "Random walk",),
"wkly_ar" => (title = "AR(1)",),
"wkly_diff_ar" => (title = "Diff. AR(1)",)
"rw" => (title = "Random walk",),
"ar" => (title = "AR(1)",),
"diff_ar" => (title = "Diff. AR(1)",)
)

##
# **Fig 2**: _Overview_: This fig aims at presenting the nowcasting (e.g. 0 horizon estimate)
# at rolling inference time points for each scenario with each inference model choice _and_
# possible misspecification of generation interval. Time horizon choice: Chosen horizon = 0
# to align with Fig 1 but with other horizons as SI plots. _Plotting details:_ 3 x 4 = 12 rows
# corresponding to 4 main scenarios (e.g. outbreak with measures etc.) and 3 main targets (e.g.
# exponential growth rate etc), the scenario GI is fixed to the middle mean GI (10 days;
# others are in SI) and 3 columns corresponding to _underestimating mean GI_ (left), good
# estimation of GI (middle) and over estimating mean GI (right). Actual values as scatter plot.
# The posterior inferred value at the estimation date_ are plotted as boxplot plot quantiles
# with colour determining the inference model.

df = EpiAwarePipeline._fig2_pred_filter(prediction_df, "smooth_outbreak", "log_I_t", "ar",
0; true_gi_choice = 10.0, used_gi_choice = 10.0)
truth_df = EpiAwarePipeline._fig_truth_filter(
truth_data_df, "smooth_outbreak", "log_I_t"; true_gi_choice = 10.0)
fig = Figure()
ax = Axis(fig[1, 1])
EpiAwarePipeline._plot_predictions!(
ax, df; igps = ["DirectInfections", "ExpGrowthRate", "Renewal"],
colors = [:red, :blue, :green], iqr_alpha = 0.3)
EpiAwarePipeline._plot_truth!(ax, truth_df; color = :black)
vlines!(ax, df.Reference_Time |> unique, color = :black, linestyle = :dash)
ax.limits = ((minimum(df.Reference_Time) - 7, maximum(df.Reference_Time) + 1), nothing)
ax.xticks = vcat(minimum(df.Reference_Time) - 7, df.Reference_Time |> unique)
fig

# figs = mapreduce(vcat, scenarios) do scenario
# mapreduce(gi_means) do true_gi_choice
# fig = figuretwo(
# truth_data_df, prediction_df, "ar", scenario_dict, target_dict;
# true_gi_choice = true_gi_choice)
# save(plotsdir("figure2_$(scenario)_trueGI_$(true_gi_choice).png"), fig)
# end
# end

##

## `ar` is the default latent model which we show as figure 1, others are for SI

figs = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model
fig = figuretwo(
prediction_df, truth_data_df, scenarios, targets, 0;
scenario_dict, target_dict, latent_model_dict,
latent_model, igps = ["DirectInfections", "ExpGrowthRate", "Renewal"],
true_gi_choice = 10.0, other_gi_choices = [2.0, 10.0, 20.0], data_color = :black,
colors = [:red, :blue, :green], iqr_alpha = 0.3, horizon_diff = 7)

fig = figuretwo(
truth_df, analysis_df, "Renewal", scenario_dict, target_dict)
_ = map(analysis_df.IGP_Model |> unique) do igp
fig = figureone(
truth_df, analysis_df, latent_model, scenario_dict, target_dict, latent_model_dict)
save(plotsdir("figure2_$(igp).png"), fig)
save(
plotsdir("figure2_$(latent_model).png"),
fig)
end
99 changes: 22 additions & 77 deletions pipeline/src/plotting/figureone.jl
Original file line number Diff line number Diff line change
@@ -1,76 +1,18 @@
"""
Plot predictions on the given axis (`ax`) based on the provided parameters.

# Arguments
- `ax`: The axis on which to plot the predictions.
- `predictions`: DataFrame containing the prediction data.
- `scenario`: The scenario to filter the predictions.
- `target`: The target to filter the predictions.
- `reference_time`: The reference time to filter the predictions.
- `latent_model`: The latent model to filter the predictions.
- `igps`: A list of IGP models to plot. Default is `["DirectInfections", "ExpGrowthRate", "Renewal"]`.
- `true_gi_choice`: The true generation interval mean to filter the predictions. Default is `2.0`.
- `used_gi_choice`: The used generation interval mean to filter the predictions. Default is `2.0`.
- `colors`: A list of colors for each IGP model. Default is `[:red, :blue, :green]`.
- `iqr_alpha`: The alpha value for the interquartile range bands. Default is `0.3`.

# Description
This function filters the `predictions` DataFrame based on the provided parameters and plots
the predictions on the given axis (`ax`). It plots the median prediction line and two
bands representing the interquartile range (IQR) and the 95% prediction interval for
each IGP model specified in `igps`.

"""
function _plot_predictions!(
ax, predictions, scenario, target, reference_time, latent_model;
igps = ["DirectInfections", "ExpGrowthRate", "Renewal"],
true_gi_choice = 2.0, used_gi_choice = 2.0, colors = [:red, :blue, :green],
iqr_alpha = 0.3)
pred = predictions |>
df -> @subset(df, :Latent_Model.==latent_model) |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |>
df -> @subset(df, :Reference_Time.==reference_time) |>
df -> @subset(df, :Scenario.==scenario) |>
df -> @subset(df, :Target.==target)
for (c, igp) in zip(colors, igps)
x = pred[pred.IGP_Model .== igp, "target_times"]
y = pred[pred.IGP_Model .== igp, "q_5"]
upr1 = pred[pred.IGP_Model .== igp, "q_75"]
upr2 = pred[pred.IGP_Model .== igp, "q_975"]
lwr1 = pred[pred.IGP_Model .== igp, "q_25"]
lwr2 = pred[pred.IGP_Model .== igp, "q_025"]
if length(x) > 0
lines!(ax, x, y, color = c, label = igp, linewidth = 3)
band!(ax, x, lwr1, upr1, color = (c, iqr_alpha))
band!(ax, x, lwr2, upr2, color = (c, iqr_alpha / 2))
end
end
return nothing
end

Filter the `predictions` DataFrame for `scenario`, `target`, `reference_time`,
`latent_model`, `true_gi_choice`, and `used_gi_choice`. This is aimed at generating
facets for figure 1.
"""
Plot the truth data on the given axis.

# Arguments
- `ax`: The axis to plot on.
- `truth`: The DataFrame containing the truth data.
- `scenario`: The scenario to filter the truth data by.
- `target`: The target to filter the truth data by.
- `true_gi_choice`: The true generation interval choice to filter the truth data by (default is 2.0).
- `color`: The color of the scatter plot (default is :black).

"""
function _plot_truth!(ax, truth, scenario, target; true_gi_choice = 2.0, color = :black)
pred = truth |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :Scenario.==scenario) |>
df -> @subset(df, :Target.==target)
x = pred[!, "target_times"]
y = pred[!, "target_values"]
scatter!(ax, x, y, color = color, label = "Data")

return nothing
function _fig1_pred_filter(predictions, scenario, target, reference_time,
latent_model; true_gi_choice = 2.0, used_gi_choice = 2.0)
df = predictions |>
df -> @subset(df, :Latent_Model.==latent_model) |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |>
df -> @subset(df, :Reference_Time.==reference_time) |>
df -> @subset(df, :Scenario.==scenario) |>
df -> @subset(df, :Target.==target)
return df
end

"""
Expand Down Expand Up @@ -103,13 +45,16 @@ function figureone(
axs = mapreduce(hcat, enumerate(targets)) do (i, target)
map(enumerate(scenarios)) do (j, scenario)
ax = Axis(fig[i, j])
_plot_predictions!(
ax, prediction_df, scenario, target, scenario_dict[scenario].T,
latent_model; true_gi_choice, used_gi_choice, colors, iqr_alpha, igps)
_plot_truth!(
ax, truth_data_df, scenario, target; true_gi_choice, color = data_color)
#Filter the data for fig1 panels
pred_df = _fig1_pred_filter(
prediction_df, scenario, target, scenario_dict[scenario].T,
latent_model; true_gi_choice, used_gi_choice)
truth_df = _fig_truth_filter(truth_data_df, scenario, target; true_gi_choice)
#Plot onto axes
_plot_predictions!(ax, pred_df; igps, colors, iqr_alpha)
_plot_truth!(ax, truth_df; color = data_color)
vlines!(ax, [scenario_dict[scenario].T], color = data_color,
linewidth = 3, label = "Horizon")
linewidth = 3, label = "Reference time")
if i == 1
ax.title = scenario_dict[scenario].title
end
Expand Down
158 changes: 71 additions & 87 deletions pipeline/src/plotting/figuretwo.jl
Original file line number Diff line number Diff line change
@@ -1,91 +1,75 @@
function _make_captions!(df, scenario_dict, target_dict)
scenario_titles = [scenario_dict[scenario].title for scenario in df.Scenario]
target_titles = [target_dict[target].title for target in df.Target]
df.Scenario_Target .= scenario_titles .* "\n" .* target_titles
return nothing
function _fig2_pred_filter(predictions, scenario, target, latent_model, horizon;
true_gi_choice, used_gi_choice, horizon_diff = 7)
df = predictions |>
df -> @subset(df, :Latent_Model.==latent_model) |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |>
df -> @subset(df,
horizon-horizon_diff.<(:target_times.-:Reference_Time).<=horizon) |>
df -> @subset(df, :Scenario.==scenario) |>
df -> @subset(df, :Target.==target)
return df
end

function _figure_two_truth_data(
truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices = [
2.0, 10.0, 20.0])
_truth_df = mapreduce(vcat, gi_choices) do used_gi
df = deepcopy(truth_df)
df.Used_GI_Mean .= used_gi
df
function figuretwo(
prediction_df, truth_data_df, scenarios, targets, horizon;
scenario_dict, target_dict, latent_model_dict,
latent_model = "ar", igps = ["DirectInfections", "ExpGrowthRate", "Renewal"],
true_gi_choice = 10.0, other_gi_choices = [2.0, 10.0, 20.0], data_color = :black,
colors = [:red, :blue, :green], iqr_alpha = 0.3, horizon_diff = 7)
fig = Figure(; size = (1000, 800 * length(scenarios)))
axs = mapreduce(vcat, enumerate(scenarios)) do (i, scenario)
n = length(targets)
Label(fig[(n * (i - 1) + 1):(n * i), 0],
scenario_dict[scenario].title, rotation = pi / 2, fontsize = 36)
mapreduce(hcat, enumerate(targets)) do (j, target)
map(enumerate(other_gi_choices)) do (k, used_gi_choice)
row = j + (i - 1) * length(targets)
ax = Axis(fig[row, k])
# #Filter the data for fig2 panels
pred_df = _fig2_pred_filter(
prediction_df, scenario, target, latent_model, horizon;
true_gi_choice, used_gi_choice, horizon_diff)
truth_df = _fig_truth_filter(
truth_data_df, scenario, target; true_gi_choice)
# #Plot onto axes
_plot_predictions!(ax, pred_df; igps, colors, iqr_alpha)
_plot_truth!(ax, truth_df; color = data_color)
vlines!(ax, pred_df.Reference_Time |> unique, color = :black,
linestyle = :dash, label = "Reference time")
# axes
if row == 1
if k == 1
ax.title = "Underestimating mean GI"
elseif k == 2
ax.title = "Good estimation of GI"
elseif k == 3
ax.title = "Overestimating mean GI"
end
end
if row == length(targets) * length(scenarios)
ax.xlabel = "Time"
end
if k == 1
ax.ylabel = target_dict[target].title
end
ax.limits = (
(minimum(pred_df.Reference_Time) - horizon_diff,
maximum(pred_df.Reference_Time) + 1),
nothing)
ax.xticks = vcat(minimum(pred_df.Reference_Time) - horizon_diff,
pred_df.Reference_Time |> unique)
ax
end
end
end
_make_captions!(_truth_df, scenario_dict, target_dict)

truth_plotting_data = _truth_df |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @transform(df, :Data="Truth data") |> data
plt_truth = truth_plotting_data *
mapping(:target_times => "T", :target_values => "Process values",
row = :Scenario_Target,
col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI",
10.0 => "Good GI", 20.0 => "Overestimate GI"]),
color = :Data => AlgebraOfGraphics.scale(:color2)) *
visual(AlgebraOfGraphics.Scatter)
return plt_truth
end

function _figure_two_scenario(
analysis_df, igp, scenario_dict, target_dict; true_gi_choice,
lower_sym = :q_025, upper_sym = :q_975)
min_ref_time = minimum(analysis_df.Reference_Time)
early_df = analysis_df |>
df -> @subset(df, :Reference_Time.==min_ref_time) |>
df -> @subset(df, :IGP_Model.==igp) |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :target_times.<=min_ref_time - 7)

seqn_df = analysis_df |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :IGP_Model.==igp) |>
df -> @subset(df,
:Reference_Time .- :target_times.∈fill(0:6, size(df, 1)))

full_df = vcat(early_df, seqn_df)
_make_captions!(full_df, scenario_dict, target_dict)

model_plotting_data = full_df |> data

plt_model = model_plotting_data *
mapping(:target_times => "T", :q_5 => "Process values",
row = :Scenario_Target,
col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI",
10.0 => "Good GI", 20.0 => "Overestimate GI"]),
color = :Latent_Model => "Latent models") *
mapping(lower = lower_sym, upper = upper_sym) *
visual(LinesFill)

return plt_model
end

function figuretwo(truth_df, analysis_df, igp, scenario_dict,
target_dict; fig_kws = (; size = (1000, 2800)),
true_gi_choice = 10.0, gi_choices = [2.0, 10.0, 20.0])

# Perform checks on the dataframes
_dataframe_checks(truth_df, analysis_df, scenario_dict)

f_td = _figure_two_truth_data(
truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices)
f_mdl = _figure_two_scenario(
analysis_df, igp, scenario_dict, target_dict; true_gi_choice)

fg = draw(f_mdl + f_td; facet = (; linkyaxes = :none),
legend = (; orientation = :horizontal, position = :bottom),
figure = fig_kws,
axis = (; xlabel = "T", ylabel = "Process values"))
for g in fg.grid[1:3:end, :]
g.axis.limits = (nothing, target_dict["rt"].ylims)
end
for g in fg.grid[2:3:end, :]
g.axis.limits = (nothing, target_dict["Rt"].ylims)
end
for g in fg.grid[3:3:end, :]
g.axis.limits = (nothing, target_dict["log_I_t"].ylims)
end

return fg
leg = Legend(fig[length(targets) * length(scenarios) + 1, 1:2],
last(axs), "Infection generating process";
orientation = :horizontal, tellwidth = false, framevisible = false)
lab = Label(fig[length(targets) * length(scenarios) + 1, length(other_gi_choices)],
"Latent model for \n infection generating\n process: $(latent_model_dict[latent_model].title) \n True mean GI: $(true_gi_choice) days \n Horizon: $(horizon) days";
tellwidth = false,
fontsize = 18)
resize_to_layout!(fig)
fig
end
Loading
Loading