Skip to content

Commit

Permalink
fanhmm as special case of nhmm
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Jan 8, 2025
1 parent 0e6d384 commit c36a7d0
Show file tree
Hide file tree
Showing 24 changed files with 191 additions and 2,198 deletions.
12 changes: 2 additions & 10 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ rho_to_phi_field <- function(rho) {
.Call(`_seqHMM_rho_to_phi_field`, rho)
}

EM_LBFGS_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) {
.Call(`_seqHMM_EM_LBFGS_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)
}

fast_quantiles <- function(X, probs) {
.Call(`_seqHMM_fast_quantiles`, X, probs)
}
Expand Down Expand Up @@ -201,10 +197,6 @@ log_objective_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, e
.Call(`_seqHMM_log_objective_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B)
}

log_objective_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) {
.Call(`_seqHMM_log_objective_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B)
}

simulate_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B) {
.Call(`_seqHMM_simulate_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B)
}
Expand All @@ -221,8 +213,8 @@ simulate_mnhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, et
.Call(`_seqHMM_simulate_mnhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, eta_omega, X_omega, M)
}

simulate_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0) {
.Call(`_seqHMM_simulate_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, rho_A, W_A, rho_B, W_B, obs_0)
simulate_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_0) {
.Call(`_seqHMM_simulate_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_0)
}

viterbi_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) {
Expand Down
116 changes: 0 additions & 116 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -320,119 +320,3 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
}
model
}
#' @rdname bootstrap
#' @export
bootstrap_coefs.fanhmm <- function(model, nsim = 1000,
type = c("nonparametric", "parametric"),
method = "EM-DNM", append = FALSE, ...) {
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = nsim, lower = 0L),
"Argument {.arg nsim} must be a single positive integer."
)
init <- c(
setNames(model$etas, c("eta_pi", "eta_A", "eta_B")),
setNames(model$rhos, c("rho_A", "rho_B"))
)
mle <- c(model$gammas, model$rhos)
lambda <- model$estimation_results$lambda
bound <- model$estimation_results$bound
p <- progressr::progressor(along = seq_len(nsim))
original_options <- options(future.globals.maxSize = Inf)
on.exit(options(original_options))
control <- model$controls$control
control$print_level <- 0
control_mstep <- model$controls$mstep
control_mstep$print_level <- 0
if (type == "nonparametric") {
out <- future.apply::future_lapply(
seq_len(nsim), function(i) {
boot_mod <- bootstrap_model(model)
fit <- fit_fanhmm(
boot_mod$model, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, bound = bound, control = control,
control_restart = list(), control_mstep = control_mstep
)
if (fit$estimation_results$return_code >= 0) {
est <- permute_states(fit$gammas, fit$rhos, mle)
fit$gammas <- est$gammas
fit$phis <- est$phis
} else {
fit$gammas <- NULL
fit$phis <- NULL
}
p()
list(gammas = fit$gammas, idx = boot_mod$idx)
}, future.seed = TRUE
)
idx <- do.call(cbind, lapply(out, "[[", "idx"))
out <- lapply(out, "[[", "gammas")
} else {
N <- model$n_sequences
T_ <- model$sequence_lengths
M <- model$n_symbols
S <- model$n_states
formula_pi <- model$initial_formula
formula_A <- model$transition_formula
formula_B <- model$emission_formula
formula_rho_A <- model$feedback_formula
formula_rho_B <- model$autoregression_formula
d <- model$data
time <- model$time_variable
id <- model$id_variable
out <- future.apply::future_lapply(
seq_len(nsim), function(i) {
mod <- simulate_fanhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
formula_rho_B, formula_rho_A,
d, time, id, init, 0)$model
fit <- fit_fanhmm(
mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, bound = bound, control = control,
control_restart = list(), control_mstep = control_mstep
)
if (fit$estimation_results$return_code >= 0) {
fit$gammas <- permute_states(fit$gammas, gammas_mle)
} else {
fit$gammas <- NULL
}
p()
fit$gammas
}, future.seed = TRUE
)
}
boot <- list(
gamma_pi = lapply(out, "[[", "pi"),
gamma_A = lapply(out, "[[", "A"),
gamma_B = lapply(out, "[[", "B"),
phi_A = lapply(out, "[[", "phi_A"),
phi_B = lapply(out, "[[", "phi_B")
)
boot <- lapply(boot, function(x) x[lengths(x) > 0])
if (length(boot[[1]]) < nsim) {
warning_(
paste0(
"Estimation in some of the bootstrap samples failed. ",
"Returning samples from {length(boot[[1]])} successes out of {nsim} ",
"bootstrap samples."
)
)
}
if (type == "nonparametric") {
boot$idx <- idx
} else {
boot$idx <- matrix(seq_len(model$n_sequences), model$n_sequences, nsim)
}
if (append && !is.null(model$boot)) {
model$boot$gamma_pi <- c(model$boot$gamma_pi, boot$gamma_pi)
model$boot$gamma_A <- c(model$boot$gamma_A, boot$gamma_A)
model$boot$gamma_B <- c(model$boot$gamma_B, boot$gamma_B)
model$boot$phi_A <- c(model$boot$phi_A, boot$phi_A)
model$boot$phi_B <- c(model$boot$phi_B, boot$phi_B)
model$boot$idx <- cbind(model$boot$idx, idx)
} else {
model$boot <- boot
}

model
}
157 changes: 77 additions & 80 deletions R/build_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,86 @@
build_fanhmm <- function(
observations, n_states, initial_formula,
transition_formula, emission_formula, autoregression_formula,
feedback_formula, data, time, id, state_names = NULL) {
feedback_formula, data, time, id, state_names = NULL, scale = TRUE) {

y_in_data <- checkmate::test_character(observations)
stopifnot_(
y_in_data && !is.null(data[[observations]]),
"For FAN-HMM, the response variable {.arg observations} must be in the {.arg data}."
)
stopifnot_(
length(observations) == 1L,
"Currently only single-channel responses are supported for FAN-HMM.")

stopifnot_(
is.null(autoregression_formula) || inherits(autoregression_formula, "formula"),
"Argument {.arg autoregression_formula} must be {.val NULL} or a {.cls formula} object."
)
stopifnot_(
is.null(feedback_formula) || inherits(feedback_formula, "formula"),
"Argument {.arg feedback_formula} must be {.val NULL} or a {.cls formula} object."
)
stopifnot_(
!is.null(autoregression_formula) || !is.null(feedback_formula),
"Provide {.arg autoregression_formula} and/or {.arg feedback_formula} for FAN-HMM."
)
stopifnot_(
inherits(initial_formula, "formula"),
"Argument {.arg initial_formula} must be a {.cls formula} object.")
stopifnot_(
inherits(transition_formula, "formula"),
"Argument {.arg transition_formula} must be a {.cls formula} object.")
stopifnot_(
inherits(emission_formula, "formula"),
"Argument {.arg emission_formula} must be a {.cls formula} object.")

data <- .check_data(data, time, id)

if (!is.null(autoregression_formula)) {
terms_autoregression <- attr(terms(autoregression_formula), "term.labels")
if (length(terms_autoregression) == 0) {
terms_autoregression <- paste0("lag_", observations)
} else {
terms_autoregression <- paste(
paste0("lag_", observations), "+",
paste(
paste0("lag_", observations),
terms_autoregression,
sep = ":"
)
)
}
emission_formula <- update(
emission_formula,
paste("~ . + ", terms_autoregression)
)
data[[paste0("lag_", observations)]] <- group_lag(data, id, observations)
}
if (!is.null(feedback_formula)) {
terms_feedback <- attr(terms(feedback_formula), "term.labels")
if (length(terms_feedback) == 0) {
terms_feedback <- observations
} else {
terms_feedback <- paste(
observations, "+",
paste(
observations,
terms_feedback,
sep = ":"
)
)
}
transition_formula <- update(
transition_formula,
paste("~ . + ", terms_feedback)
)
}
obs_0 <- data[[observations]][data[[time]] == min(data[[time]])]
data <- data[data[[time]] > min(data[[time]]), ]
out <- create_base_nhmm(
observations, data, time, id, n_states, state_names, channel_names = NULL,
initial_formula, transition_formula, emission_formula)
initial_formula, transition_formula, emission_formula, scale = scale,
check_formulas = FALSE)
stopifnot_(
!any(out$model$observations == attr(out$model$observations, "nr")),
"FAN-HMM does not support missing values in the observations."
Expand All @@ -21,79 +92,8 @@ build_fanhmm <- function(
out$model$etas <- setNames(
create_initial_values(list(), out$model, 0), c("pi", "A", "B")
)
stopifnot_(
out$model$n_channels == 1L,
"Currently only single-channel responses are supported for FAN-HMM.")
if(is.null(feedback_formula)) {
out$model$W_A <- array(
0, c(0L, out$model$length_of_sequences - 1, out$model$n_sequences)
)
out$model$rhos$A <- create_rho_A_inits(
NULL, n_states, out$model$n_symbols, 0, 0
)
np_rho_A <- 0
} else {
stopifnot_(
inherits(feedback_formula, "formula"),
"Argument {.arg feedback_formula} must be a {.cls formula} object.")
W_A <- model_matrix_feedback_formula(
feedback_formula, data,
out$model$n_sequences,
out$model$length_of_sequences, n_states,
out$model$n_symbols, time, id,
out$model$sequence_lengths
)
x_attr <- attributes(W_A$X)
out$model$W_A <- W_A$X[, -1, , drop = FALSE]
x_attr$dim[2] <- x_attr$dim[2] - 1L
attributes(out$model$W_A) <- x_attr
out$model$rhos$A <- create_rho_A_inits(
NULL, n_states, out$model$n_symbols, nrow(out$model$W_A), 0
)
out$extras$intercept_only <- FALSE
np_rho_A <- W_A$n_pars
}
if(is.null(autoregression_formula)) {
out$model$W_B <- array(
0, c(0L, out$model$length_of_sequences - 1, out$model$n_sequences)
)
out$model$rhos$B <- create_rho_B_inits(
NULL, n_states, out$model$n_symbols, 0, 0
)
np_rho_B <- 0
} else {
stopifnot_(
inherits(autoregression_formula, "formula"),
"Argument {.arg autoregression_formula} must be a {.cls formula} object.")
W_B <- model_matrix_autoregression_formula(
autoregression_formula, data,
out$model$n_sequences,
out$model$length_of_sequences, n_states,
out$model$n_symbols, time, id,
out$model$sequence_lengths
)
x_attr <- attributes(W_B$X)
out$model$W_B <- W_B$X[, -1, , drop = FALSE]
x_attr$dim[2] <- x_attr$dim[2] - 1L
attributes(out$model$W_B) <- x_attr
out$model$rhos$B <- create_rho_B_inits(
NULL, n_states, out$model$n_symbols, nrow(out$model$W_B), 0
)
out$extras$intercept_only <- FALSE
np_rho_B <- W_B$n_pars
}
out$model$obs_0 <- as.integer(out$model$observations[, 1]) - 1L
out$model$observations <- out$model$observations[, -1]
out$model$length_of_sequences <- out$model$length_of_sequences - 1
out$model$sequence_lengths <- out$model$sequence_lengths - 1
x_attr <- attributes(out$model$X_A)
out$model$X_A <- out$model$X_A[, -1, , drop = FALSE]
x_attr$dim[2] <- x_attr$dim[2] - 1L
attributes(out$model$X_A) <- x_attr
x_attr <- attributes(out$model$X_B)
out$model$X_B <- out$model$X_B[, -1, , drop = FALSE]
x_attr$dim[2] <- x_attr$dim[2] - 1L
attributes(out$model$X_B) <- x_attr

out$model$obs_0 <- obs_0
structure(
c(
out$model,
Expand All @@ -104,14 +104,11 @@ build_fanhmm <- function(
),
class = c("fanhmm", "nhmm"),
nobs = attr(out$model$observations, "nobs"),
df = out$extras$np_pi + out$extras$np_A + out$extras$np_B + np_rho_A +
np_rho_B,
df = out$extras$np_pi + out$extras$np_A + out$extras$np_B,
type = paste0(out$extras$multichannel, "fanhmm"),
intercept_only = out$extras$intercept_only,
np_pi = out$extras$np_pi,
np_A = out$extras$np_A,
np_B = out$extras$np_B,
np_rho_A = np_rho_A,
np_rho_B = np_rho_B
np_B = out$extras$np_B
)
}
4 changes: 2 additions & 2 deletions R/build_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build_mnhmm <- function(
observations, n_states, n_clusters, initial_formula,
transition_formula, emission_formula, cluster_formula,
data, time, id, state_names = NULL, channel_names = NULL,
cluster_names = NULL) {
cluster_names = NULL, scale = TRUE) {

stopifnot_(
!missing(n_clusters) && checkmate::test_int(x = n_clusters, lower = 2L),
Expand All @@ -22,7 +22,7 @@ build_mnhmm <- function(
out <- create_base_nhmm(
observations, data, time, id, n_states, state_names, channel_names,
initial_formula, transition_formula, emission_formula, cluster_formula,
cluster_names)
cluster_names, scale = scale)
out$model$etas <- setNames(
create_initial_values(list(), out$model, 0), c("pi", "A", "B", "omega")
)
Expand Down
8 changes: 4 additions & 4 deletions R/build_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
#'
#' @noRd
build_nhmm <- function(
observations, n_states, initial_formula,
transition_formula, emission_formula,
data, time, id, state_names = NULL, channel_names = NULL) {
observations, n_states, initial_formula, transition_formula,
emission_formula, data, time, id, state_names = NULL, channel_names = NULL,
scale = TRUE) {

out <- create_base_nhmm(
observations, data, time, id, n_states, state_names, channel_names,
initial_formula, transition_formula, emission_formula)
initial_formula, transition_formula, emission_formula, scale = scale)
out[c("cluster_names", "n_clusters", "X_omega")] <- NULL
out$model$etas <- setNames(
create_initial_values(list(), out$model, 0), c("pi", "A", "B")
Expand Down
Loading

0 comments on commit c36a7d0

Please sign in to comment.