Commit

Issue 3: Not using DT. (#4)
Former-commit-id: 1a625cd
Former-commit-id: 68d2d0232db3dc09efaa429d5641b32b4aee1f2f
parksw3 authored Apr 4, 2024
1 parent 9d1dd53 commit adc6d5c
Showing 8 changed files with 196 additions and 216 deletions.
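
The pattern throughout this commit is uniform: piped calls through data.table's experimental DT() helper are replaced by an explicit data.table::copy() followed by standard bracket syntax. A minimal sketch of the idiom change, using a hypothetical table rather than code from this repository:

library(data.table)

dt <- data.table(x = 1:3)

# Before: pipe through the experimental DT() helper
# out <- dt |> copy() |> DT(, y := x * 2)

# After: copy explicitly, then use standard data.table syntax
out <- data.table::copy(dt)
out[, y := x * 2]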
56 changes: 27 additions & 29 deletions R/fitting-and-postprocessing.R
@@ -3,8 +3,7 @@
sample_model <- function(model, data, scenario = data.table::data.table(id = 1),
diagnostics = TRUE, ...) {

out <- scenario |>
copy()
out <- data.table::copy(scenario)

# Setup failure tolerant model fitting
fit_model <- function(model, data, ...) {
@@ -16,12 +15,10 @@ sample_model <- function(model, data, scenario = data.table::data.table(id = 1),
fit <- safe_fit_model(model, data, ...)

if (!is.null(fit$error)) {
out <- out |>
DT(, error := list(fit$error[[1]]))
out[, error := list(fit$error[[1]])]
diagnostics <- FALSE
} else {
out <- out |>
DT(, fit := list(fit$result))
out[, fit := list(fit$result)]
fit <- fit$result
}

@@ -57,8 +54,7 @@ sample_epinowcast_model <- function(
diagnostics = TRUE, ...
) {

out <- scenario |>
copy()
out <- data.table::copy(scenario)

# Setup failure tolerant model fitting
fit_model <- function(model, data, ...) {
@@ -72,12 +68,10 @@
fit <- safe_fit_model(model, data, ...)

if (!is.null(fit$error)) {
out <- out |>
DT(, error := list(fit$error[[1]]))
out[, error := list(fit$error[[1]])]
diagnostics <- FALSE
} else {
out <- out |>
DT(, fit := list(fit$result))
out[, fit := list(fit$result)]
fit <- fit$result
}

@@ -123,11 +117,12 @@ sample_epinowcast_model <- function(
#' Add natural scale summary parameters for a lognormal distribution
#' @export
add_natural_scale_mean_sd <- function(dt) {
nat_dt <- dt |>
data.table::DT(, mean := exp(meanlog + sdlog ^ 2 / 2)) |>
data.table::DT(,
sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1)
)
nat_dt <- data.table::copy(dt)

nat_dt <- nat_dt[, mean := exp(meanlog + sdlog ^ 2 / 2)]

nat_dt <- nat_dt[, sd := exp(meanlog + (1 / 2) * sdlog ^ 2) * sqrt(exp(sdlog ^ 2) - 1)]

return(nat_dt[])
}

@@ -186,8 +181,7 @@ extract_epinowcast_draws <- function(
)
}

draws <- draws |>
data.table::setDT()
draws <- data.table::setDT(draws)

data.table::setnames(
draws, c("refp_mean_int[1]", "refp_sd_int[1]"), c("meanlog", "sdlog"),
@@ -207,10 +201,11 @@
#' Primary event bias correction
#' @export
primary_censoring_bias_correction <- function(draws) {
draws <- data.table::copy(draws) |>
DT(, mean := mean - runif(.N, min = 0, max = 1)) |>
DT(, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))) |>
DT(, sdlog := sqrt(log(1 + (sd^2 / mean^2))))
draws <- data.table::copy(draws)
draws[, mean := mean - runif(.N, min = 0, max = 1)]
draws[, meanlog := log(mean^2 / sqrt(sd^2 + mean^2))]
draws[, sdlog := sqrt(log(1 + (sd^2 / mean^2)))]

return(draws[])
}
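
The two functions above use the standard lognormal moment-matching identities: add_natural_scale_mean_sd maps (meanlog, sdlog) to the natural-scale mean and sd, and primary_censoring_bias_correction inverts that mapping after shifting the mean. A minimal sketch checking the round trip, with hypothetical values and without the runif() shift so the inversion is exact:

library(data.table)

d <- data.table(meanlog = 1.2, sdlog = 0.5)

# Natural-scale moments of a lognormal distribution
d[, mean := exp(meanlog + sdlog^2 / 2)]
d[, sd := exp(meanlog + sdlog^2 / 2) * sqrt(exp(sdlog^2) - 1)]

# Invert by moment matching; recovers meanlog = 1.2 and sdlog = 0.5
d[, meanlog2 := log(mean^2 / sqrt(sd^2 + mean^2))]
d[, sdlog2 := sqrt(log(1 + (sd^2 / mean^2)))]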

@@ -234,7 +229,8 @@ make_relative_to_truth <- function(draws, secondary_dist, by = "parameter") {
by = by
)

draws <- draws[, rel_value := value / true_value]
draws[, rel_value := value / true_value]

return(draws[])
}

@@ -289,9 +285,11 @@ summarise_variable <- function(draws, variable, sf = 6, by = c()) {
if (missing(variable)) {
stop("variable must be specified")
}
summarised_draws <- draws |>
copy() |>
DT(, value := variable, env = list(variable = variable)) |>
summarise_draws(sf = sf, by = by)
summarised_draws <- data.table::copy(draws)

summarised_draws[, value := variable, env = list(variable = variable)]

summarised_draws <- summarise_draws(summarised_draws, sf = sf, by = by)

return(summarised_draws[])
}
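
The env argument used in summarise_variable is data.table's substitution interface for programming on column names (available in recent data.table releases): symbols listed in env are replaced before j is evaluated, and length-one character values are converted to column symbols. A minimal sketch with hypothetical data:

library(data.table)

d <- data.table(x = 1:3, y = 4:6)
col <- "y"

# variable is substituted by the column named in col, so value copies y
d[, value := variable, env = list(variable = col)]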
69 changes: 30 additions & 39 deletions R/models.R
@@ -14,11 +14,10 @@ naive_delay <- function(formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data,
filtered_naive_delay <- function(
formula = brms::bf(delay_daily ~ 1, sigma ~ 1), data, fn = brms::brm,
family = "lognormal", truncation = 10, ...) {
data <- data |>
data.table::as.data.table() |>
## NEED TO FILTER BASED ON PTIME
DT(ptime_daily <= (obs_at - truncation))

data <- data.table::as.data.table(data)
## NEED TO FILTER BASED ON PTIME
data <- data[ptime_daily <= (obs_at - truncation)]

data <- drop_zero(data)

fn(
@@ -85,12 +84,11 @@ latent_censoring_adjusted_delay <- function(

stanvars_all <- stanvars_functions + stanvars_parameters + stanvars_prior

data <- data |>
data.table::as.data.table() |>
DT(, id := 1:.N) |>
DT(, pwindow_upr := ptime_upr - ptime_lwr) |>
DT(, swindow_upr := stime_upr - stime_lwr) |>
DT(, delay_central := stime_lwr - ptime_lwr)
data <- data.table::as.data.table(data)
data[, id := 1:.N]
data[, pwindow_upr := ptime_upr - ptime_lwr]
data[, swindow_upr := stime_upr - stime_lwr]
data[, delay_central := stime_lwr - ptime_lwr]

if (nrow(data) > 1) {
data <- data[, id := as.factor(id)]
@@ -111,9 +109,8 @@ filtered_censoring_adjusted_delay <- function(
delay_lwr | cens(censored, delay_upr) ~ 1, sigma ~ 1
), data, fn = brms::brm, family = "lognormal", truncation = 10, ...) {

data <- data |>
data.table::as.data.table() |>
DT(ptime_daily <= (obs_at - truncation))
data <- data.table::as.data.table(data)
data <- data[ptime_daily <= (obs_at - truncation)]

data <- pad_zero(data)

@@ -201,22 +198,18 @@ latent_truncation_censoring_adjusted_delay <- function(
...
) {

data <- data |>
data.table::as.data.table() |>
DT(, id := 1:.N) |>
DT(, obs_t := obs_at - ptime_lwr) |>
DT(, pwindow_upr := ifelse(
stime_lwr < ptime_upr, ## if overlap
stime_upr - ptime_lwr,
ptime_upr - ptime_lwr
)
) |>
DT(,
woverlap := as.numeric(stime_lwr < ptime_upr)
) |>
DT(, swindow_upr := stime_upr - stime_lwr) |>
DT(, delay_central := stime_lwr - ptime_lwr) |>
DT(, row_id := 1:.N)
data <- data.table::as.data.table(data)
data[, id := 1:.N]
data[, obs_t := obs_at - ptime_lwr]
data[, pwindow_upr := ifelse(
stime_lwr < ptime_upr, ## if overlap
stime_upr - ptime_lwr,
ptime_upr - ptime_lwr
)]
data[, woverlap := as.numeric(stime_lwr < ptime_upr)]
data[, swindow_upr := stime_upr - stime_lwr]
data[, delay_central := stime_lwr - ptime_lwr]
data[, row_id := 1:.N]

if (nrow(data) > 1) {
data <- data[, id := as.factor(id)]
@@ -323,9 +316,8 @@ dynamical_censoring_adjusted_delay <- function(
)
}
cols <- colnames(data)[map_lgl(data, is.integer)]
data <- data |>
data.table::as.data.table() |>
DT(, (cols) := lapply(.SD, as.double), .SDcols = cols)
data <- data.table::as.data.table(data)
data[, (cols) := lapply(.SD, as.double), .SDcols = cols]

data <- drop_zero(data)
## need to do this because lognormal doesn't like zero
@@ -433,12 +425,11 @@ epinowcast_delay <- function(formula = ~ 1, data, by = c(),
"epinowcast is not installed. Please install it to use this function"
)
}
data_as_counts <- data |>
data.table::as.data.table() |>
DT(, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)) |>
DT(order(ptime_daily, stime_daily)) |>
DT(, reference_date := as.Date("2000-01-01") + ptime_daily) |>
DT(, report_date := as.Date("2000-01-01") + stime_daily)
data_as_counts <- data.table::as.data.table(data)
data_as_counts <- data_as_counts[, .(new_confirm = .N), by = c("ptime_daily", "stime_daily", by)]
data_as_counts <- data_as_counts[order(ptime_daily, stime_daily)]
data_as_counts[, reference_date := as.Date("2000-01-01") + ptime_daily]
data_as_counts[, report_date := as.Date("2000-01-01") + stime_daily]

# Actual largest observable delay
preprocess_delay <- min(
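
The count aggregation in epinowcast_delay relies on data.table's .N symbol, which gives the number of rows per group. A minimal sketch of the same idiom on a hypothetical linelist:

library(data.table)

linelist <- data.table(
  ptime_daily = c(0, 0, 1, 1, 1),
  stime_daily = c(2, 2, 3, 3, 4)
)

counts <- linelist[, .(new_confirm = .N), by = .(ptime_daily, stime_daily)]
counts <- counts[order(ptime_daily, stime_daily)]
counts[, reference_date := as.Date("2000-01-01") + ptime_daily]
counts[, report_date := as.Date("2000-01-01") + stime_daily]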
101 changes: 47 additions & 54 deletions R/observe.R
@@ -1,42 +1,37 @@
#' Observation process for primary and secondary events
#' @export
observe_process <- function(linelist) {
clinelist <- linelist |>
data.table::copy() |>
DT(, ptime_daily := floor(ptime)) |>
DT(, ptime_lwr := ptime_daily) |>
DT(, ptime_upr := ptime_daily + 1) |>
# How the second event would be recorded in the data
DT(, stime_daily := floor(stime)) |>
DT(, stime_lwr := stime_daily) |>
DT(, stime_upr := stime_daily + 1) |>
# How would we observe the delay distribution
# previously: delay_daily=floor(delay)
DT(, delay_daily := stime_daily - ptime_daily) |>
DT(, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))) |>
DT(, delay_upr := delay_daily + 1) |>
# We assume observation time is the ceiling of the maximum delay
DT(, obs_at := stime |>
max() |>
ceiling()
)
clinelist <- data.table::copy(linelist)
clinelist[, ptime_daily := floor(ptime)]
clinelist[, ptime_lwr := ptime_daily]
clinelist[, ptime_upr := ptime_daily + 1]
# How the second event would be recorded in the data
clinelist[, stime_daily := floor(stime)]
clinelist[, stime_lwr := stime_daily]
clinelist[, stime_upr := stime_daily + 1]
# How would we observe the delay distribution
# previously: delay_daily=floor(delay)
clinelist[, delay_daily := stime_daily - ptime_daily]
clinelist[, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))]
clinelist[, delay_upr := delay_daily + 1]
# We assume observation time is the ceiling of the maximum delay
clinelist[, obs_at := stime |>
max() |>
ceiling()]

return(clinelist)
}

#' Filter observations based on an observation time of secondary events
#' @export
filter_obs_by_obs_time <- function(linelist, obs_time) {
truncated_linelist <- linelist |>
data.table::copy() |>
# Update observation time by when we are looking
DT(, obs_at := obs_time) |>
DT(, obs_time := obs_time - ptime) |>
# Assuming truncation at the beginning of the censoring window
DT(,
censored_obs_time := obs_at - ptime_lwr
) |>
DT(, censored := "interval") |>
DT(stime_upr <= obs_at)
truncated_linelist <- data.table::copy(linelist)
truncated_linelist[, obs_at := obs_time]
truncated_linelist[, obs_time := obs_time - ptime]
truncated_linelist[, censored_obs_time := obs_at - ptime_lwr]
truncated_linelist[, censored := "interval"]
truncated_linelist <- truncated_linelist[stime_upr <= obs_at]

return(truncated_linelist)
}

@@ -47,49 +42,47 @@ filter_obs_by_ptime <- function(linelist, obs_time,
obs_at <- match.arg(obs_at)

pfilt_t <- obs_time
truncated_linelist <- linelist |>
data.table::copy() |>
DT(, censored := "interval") |>
DT(ptime_upr <= pfilt_t)
truncated_linelist <- data.table::copy(linelist)

truncated_linelist[, censored := "interval"]
truncated_linelist <- truncated_linelist[ptime_upr <= pfilt_t]

if (obs_at == "obs_secondary") {
truncated_linelist <- truncated_linelist |>
# Update observation time to be the same as the maximum secondary time
DT(, obs_at := stime_upr)
# Update observation time to be the same as the maximum secondary time
truncated_linelist[, obs_at := stime_upr]
} else if (obs_at == "max_secondary") {
truncated_linelist <- truncated_linelist |>
DT(, obs_at := stime_upr |> max() |> ceiling())
truncated_linelist[, obs_at := stime_upr |> max() |> ceiling()]
}

# make observation time as specified
truncated_linelist <- truncated_linelist |>
DT(, obs_time := obs_at - ptime) |>
# Assuming truncation at the beginning of the censoring window
DT(, censored_obs_time := obs_at - ptime_lwr)
truncated_linelist[, obs_time := obs_at - ptime]
# Assuming truncation at the beginning of the censoring window
truncated_linelist[, censored_obs_time := obs_at - ptime_lwr]

# set observation time to artificial observation time
if (obs_at == "obs_secondary") {
truncated_linelist <- truncated_linelist |>
DT(, obs_at := pfilt_t)
truncated_linelist[, obs_at := pfilt_t]
}
return(truncated_linelist)
}

#' Pad zero observations, as they are unstable in a lognormal distribution
#' @export
pad_zero <- function(data, pad = 1e-3) {
data <- data |>
data.table::copy() |>
# Need upper bound to be greater than lower bound
DT(censored_obs_time == 0, censored_obs_time := 2 * pad) |>
DT(delay_lwr == 0, delay_lwr := pad) |>
DT(delay_daily == 0, delay_daily := pad)
data <- data.table::copy(data)
# Need upper bound to be greater than lower bound
data[censored_obs_time == 0, censored_obs_time := 2 * pad]
data[delay_lwr == 0, delay_lwr := pad]
data[delay_daily == 0, delay_daily := pad]

return(data)
}

#' Drop zero observations, as they are unstable in a lognormal distribution
#' @export
drop_zero <- function(data) {
data <- data |>
data.table::copy() |>
DT(delay_daily != 0)
data <- data.table::copy(data)
data <- data[delay_daily != 0]

return(data)
}
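
One subtlety of the bracket syntax adopted here: := assigns by reference and needs no reassignment, while a row filter such as data[delay_daily != 0] returns a new object that must be reassigned, as drop_zero does above. A minimal sketch:

library(data.table)

d <- data.table(delay_daily = c(0, 1, 2))

d[, nonzero := delay_daily != 0]  # := updates d in place

d[delay_daily != 0]               # returns a subset; d itself is unchanged
d <- d[delay_daily != 0]          # reassignment keeps only the filtered rows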
