Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor improvements to helper/checker functions #192

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' string corresponding to the R distribution function (e.g., "rpois" for
#' Poisson.
#' @keywords internal
check_offspring_func_valid <- function(roffspring_name) {
.check_offspring_func_valid <- function(roffspring_name) {
checkmate::assert(
exists(roffspring_name) ||
checkmate::assert_function(get(roffspring_name)),
Expand All @@ -18,7 +18,7 @@ check_offspring_func_valid <- function(roffspring_name) {
#' @inheritParams simulate_chains
#'
#' @keywords internal
check_generation_time_valid <- function(generation_time) {
.check_generation_time_valid <- function(generation_time) {
checkmate::assert_function(generation_time, nargs = 1)
x <- generation_time(10)
checkmate::assert_numeric(x, len = 10)
Expand Down
4 changes: 2 additions & 2 deletions R/epichains.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ epichains_tree <- function(tree_df,
checkmate::assert_integerish(index_cases, null.ok = TRUE)
checkmate::assert_character(statistic, null.ok = TRUE)
checkmate::assert_string(offspring_dist)
check_offspring_func_valid(paste0("r", offspring_dist))
.check_offspring_func_valid(paste0("r", offspring_dist))
checkmate::assert_logical(track_pop)
checkmate::assert_number(stat_max, null.ok = TRUE)

Expand Down Expand Up @@ -143,7 +143,7 @@ epichains_summary <- function(chains_summary,
checkmate::assert_integerish(index_cases, null.ok = TRUE)
checkmate::assert_character(statistic)
checkmate::assert_string(offspring_dist)
check_offspring_func_valid(paste0("r", offspring_dist))
.check_offspring_func_valid(paste0("r", offspring_dist))
checkmate::assert_number(stat_max, null.ok = TRUE)

# Create <epichains_summary> object
Expand Down
35 changes: 18 additions & 17 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
#' @param n_offspring A vector of offspring per chain.
#' @return A vector of chain statistics (size/length).
#' @keywords internal
update_chain_stat <- function(stat_type, stat_latest, n_offspring) {
if (stat_type == "size") {
stat_latest <- stat_latest + n_offspring
} else if (stat_type == "length") {
stat_latest <- stat_latest + pmin(1, n_offspring)
}

return(stat_latest)
.update_chain_stat <- function(stat_type, stat_latest, n_offspring) {
return(
switch(
stat_type,
size = stat_latest + n_offspring,
length = stat_latest + pmin(1, n_offspring)
)
)
}

#' Return a function for calculating chain statistics
Expand All @@ -21,13 +21,14 @@ update_chain_stat <- function(stat_type, stat_latest, n_offspring) {
#'
#' @return a function for calculating chain statistics
#' @keywords internal
get_statistic_func <- function(chain_statistic) {
func <- if (chain_statistic == "size") {
rbinom_size
} else if (chain_statistic == "length") {
rgen_length
}
return(func)
.get_statistic_func <- function(chain_statistic) {
return(
switch(
chain_statistic,
size = rbinom_size,
length = rgen_length
)
)
}

#' Construct name of analytical function for estimating loglikelihood of
Expand All @@ -37,7 +38,7 @@ get_statistic_func <- function(chain_statistic) {
#'
#' @return an analytical offspring likelihood function
#' @keywords internal
construct_offspring_ll_name <- function(offspring_dist, chain_statistic) {
.construct_offspring_ll_name <- function(offspring_dist, chain_statistic) {
ll_name <- paste(offspring_dist, chain_statistic, "ll", sep = "_")
return(ll_name)
}
Expand All @@ -49,7 +50,7 @@ construct_offspring_ll_name <- function(offspring_dist, chain_statistic) {
#'
#' @return numeric; adjusted next generation offspring vector
#' @keywords internal
adjust_next_gen <- function(next_gen, susc_pop) {
.adjust_next_gen <- function(next_gen, susc_pop) {
## create hypothetical next generation individuals to sample from
next_gen_pop <- rep(
seq_along(next_gen),
Expand Down
4 changes: 2 additions & 2 deletions R/likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ likelihood <- function(chains, statistic = c("size", "length"), offspring_dist,
stop("'nsim_obs' must be specified if 'obs_prob' is < 1")
}

statistic_func <- get_statistic_func(statistic)
statistic_func <- .get_statistic_func(statistic)

stat_rep_list <- replicate(nsim_obs, pmin(
statistic_func(
Expand Down Expand Up @@ -109,7 +109,7 @@ likelihood <- function(chains, statistic = c("size", "length"), offspring_dist,

## get log-likelihood function as given by offspring_dist and statistic
likelihoods <- vector(mode = "numeric")
ll_func <- construct_offspring_ll_name(offspring_dist, statistic)
ll_func <- .construct_offspring_ll_name(offspring_dist, statistic)
pars <- as.list(unlist(list(...))) ## converts vectors to lists

## calculate log-likelihoods
Expand Down
12 changes: 6 additions & 6 deletions R/simulate.r
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ simulate_chains <- function(index_cases,
"r",
offspring_dist
)
check_offspring_func_valid(roffspring_name)
.check_offspring_func_valid(roffspring_name)
checkmate::assert(
is.infinite(stat_max) ||
checkmate::assert_integerish(stat_max, lower = 0)
Expand All @@ -164,7 +164,7 @@ simulate_chains <- function(index_cases,
lower = 0, upper = 1
)
if (!missing(generation_time)) {
check_generation_time_valid(generation_time)
.check_generation_time_valid(generation_time)
} else if (!missing(tf)) {
stop("If `tf` is specified, `generation_time` must be specified too.")
}
Expand Down Expand Up @@ -237,7 +237,7 @@ simulate_chains <- function(index_cases,
# Adjust next_gen if the number of offspring is greater than the
# susceptible population.
if (sum(next_gen) > susc_pop) {
next_gen <- adjust_next_gen(
next_gen <- .adjust_next_gen(
next_gen = next_gen,
susc_pop = susc_pop
)
Expand All @@ -250,7 +250,7 @@ simulate_chains <- function(index_cases,
# assign offspring sum to indices still being simulated
n_offspring[sim] <- tapply(next_gen, parent_ids, sum)
# track size/length
stat_track <- update_chain_stat(
stat_track <- .update_chain_stat(
stat_type = statistic,
stat_latest = stat_track,
n_offspring = n_offspring
Expand Down Expand Up @@ -363,7 +363,7 @@ simulate_summary <- function(index_cases, statistic = c("size", "length"),

# check that offspring function exists in base R
roffspring_name <- paste0("r", offspring_dist)
check_offspring_func_valid(roffspring_name)
.check_offspring_func_valid(roffspring_name)

checkmate::assert_number(
stat_max, lower = 0
Expand Down Expand Up @@ -400,7 +400,7 @@ simulate_summary <- function(index_cases, statistic = c("size", "length"),
n_offspring[sim] <- tapply(next_gen, indices, sum)

# track size/length
stat_track <- update_chain_stat(
stat_track <- .update_chain_stat(
stat_type = statistic,
stat_latest = stat_track,
n_offspring = n_offspring
Expand Down
6 changes: 3 additions & 3 deletions man/adjust_next_gen.Rd → man/dot-adjust_next_gen.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/get_statistic_func.Rd → man/dot-get_statistic_func.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/update_chain_stat.Rd → man/dot-update_chain_stat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
test_that("Checks work", {
expect_error(
check_offspring_func_valid("rrpois"),
.check_offspring_func_valid("rrpois"),
"not found"
)
expect_error(
check_generation_time_valid("a"),
.check_generation_time_valid("a"),
"Must be a function"
)
expect_error(
check_generation_time_valid(function(x) rep("a", 10)),
.check_generation_time_valid(function(x) rep("a", 10)),
"numeric"
)
expect_error(
check_generation_time_valid(function(x) 3),
.check_generation_time_valid(function(x) 3),
"Must have length"
)
})
10 changes: 5 additions & 5 deletions tests/testthat/test-helpers.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("construct_offspring_ll_name works correctly", {
expect_identical(
construct_offspring_ll_name(
.construct_offspring_ll_name(
offspring_dist = "pois",
chain_statistic = "size"
),
Expand All @@ -12,15 +12,15 @@ test_that("update_chain_stat works correctly", {
stat_latest <- 1
n_offspring <- 2
expect_identical(
update_chain_stat(
.update_chain_stat(
stat_type = "size",
stat_latest = stat_latest,
n_offspring = n_offspring
),
stat_latest + n_offspring
)
expect_identical(
update_chain_stat(
.update_chain_stat(
stat_type = "length",
stat_latest = stat_latest,
n_offspring = n_offspring
Expand All @@ -31,11 +31,11 @@ test_that("update_chain_stat works correctly", {

test_that("get_statistic_func works correctly", {
expect_identical(
get_statistic_func(chain_statistic = "size"),
.get_statistic_func(chain_statistic = "size"),
rbinom_size
)
expect_identical(
get_statistic_func(chain_statistic = "length"),
.get_statistic_func(chain_statistic = "length"),
rgen_length
)
})
Loading