Skip to content

Commit

Permalink
default bound to inf, fix status code, fix gradients with zero expect…
Browse files Browse the repository at this point in the history
…ed counts
  • Loading branch information
helske committed Dec 12, 2024
1 parent 5963ea3 commit 931b631
Show file tree
Hide file tree
Showing 26 changed files with 122 additions and 144 deletions.
6 changes: 3 additions & 3 deletions R/dnm_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
)
if (fit$status == -1 && need_grad) {
grad_norm <- sqrt(sum(objectivef(fit$solution)$gradient^2))
if (grad_norm < 1e-6) fit$status <- 6
if (grad_norm < 1e-6) fit$status <- 7
}
p()
fit
Expand Down Expand Up @@ -169,9 +169,9 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control
)
if (ou$status == -1 && need_grad) {
if (out$status == -1 && need_grad) {
grad_norm <- sqrt(sum(objectivef(out$solution)$gradient^2))
if (grad_norm < 1e-6) out$status <- 6
if (grad_norm < 1e-6) out$status <- 7
}
if (out$status < 0) {
warning_(
Expand Down
2 changes: 1 addition & 1 deletion R/dnm_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
)
if (fit$status == -1 && need_grad) {
grad_norm <- sqrt(sum(objectivef(fit$solution)$gradient^2))
if (grad_norm < 1e-6) fit$status <- 6
if (grad_norm < 1e-6) fit$status <- 7
}
p()
fit
Expand Down
2 changes: 1 addition & 1 deletion R/em_dnm_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda,
)
if (fit$status == -1 && need_grad) {
grad_norm <- sqrt(sum(objectivef(fit$solution)$gradient^2))
if (grad_norm < 1e-6) fit$status <- 6
if (grad_norm < 1e-6) fit$status <- 7
}
p()
fit
Expand Down
2 changes: 1 addition & 1 deletion R/em_dnm_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda,
)
if (fit$status == -1 && need_grad) {
grad_norm <- sqrt(sum(objectivef(fit$solution)$gradient^2))
if (grad_norm < 1e-6) fit$status <- 6
if (grad_norm < 1e-6) fit$status <- 7
}
p()
fit
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, lambda = 0, method = "EM-DNM", bound = 50,
restarts = 0L, lambda = 0, method = "EM-DNM", bound = Inf,
control_restart = list(), control_mstep = list(), store_data = TRUE, ...) {

call <- match.call()
Expand Down
8 changes: 4 additions & 4 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@
#' direct maximization of the log-likelihood, by default using L-BFGS. Option
#' `"EM-DNM"` (the default) runs first a maximum of 10 iterations of EM and
#' then switches to L-BFGS (but other algorithms of NLopt can be used).
#' @param bound Positive value defining the hard bounds for the working
#' parameters \eqn{\eta}, which are used to avoid extreme probabilities and
#' @param bound Positive value defining the hard lower and upper bounds for the
#' working parameters \eqn{\eta}, which are used to avoid extreme probabilities and
#' corresponding numerical issues especially in the M-step of EM algorithm.
#' Default is 50, i.e., \eqn{-50<\eta<50}. Note that the bounds are not enforced
#' Default is `Inf`, i.e., no bounds. Note that the bounds are not enforced
#' for M-step in intercept-only case with `lambda = 0`.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
#' is stored to the model object. For large datasets, this can be set to
Expand Down Expand Up @@ -121,7 +121,7 @@ estimate_nhmm <- function(
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, inits = "random", init_sd = 2, restarts = 0L,
lambda = 0, method = "EM-DNM", bound = 50, control_restart = list(),
lambda = 0, method = "EM-DNM", bound = Inf, control_restart = list(),
control_mstep = list(), store_data = TRUE, ...) {

call <- match.call()
Expand Down
2 changes: 1 addition & 1 deletion R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
control_restart
)
stopifnot_(
identical(control$algorithm, control_restart$algorithm),
restarts == 0 || identical(control$algorithm, control_restart$algorithm),
c("Cannot mix different algorithms for multistart and final optimization.",
"Found algorithm {.val {control$algorithm}} for final optimization and
{.val {control_restart$algorithm}} for multistart.")
Expand Down
2 changes: 1 addition & 1 deletion R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control_restart
)
stopifnot_(
identical(control$algorithm, control_restart$algorithm),
restarts == 0 || identical(control$algorithm, control_restart$algorithm),
c("Cannot mix different algorithms for multistart and final optimization.",
"Found algorithm {.val {control$algorithm}} for final optimization and
{.val {control_restart$algorithm}} for multistart.")
Expand Down
2 changes: 1 addition & 1 deletion R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ return_msg <- function(code) {
if (code == 7) {
msg <- paste0(
"NLopt terminated with generic error code -1. ",
"Gradient norm was less than 1e-6 likely converged successfully."
"Gradient norm was less than 1e-6, so likely converged successfully."
)
}
paste0(x, msg)
Expand Down
9 changes: 4 additions & 5 deletions man/estimate_mnhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions man/estimate_nhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 25 additions & 41 deletions src/mnhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,19 @@ double mnhmm_base::objective_omega(const arma::vec& x, arma::vec& grad) {
grad.zeros();
arma::mat tQd = Qd.t();
arma::uvec idx(D);
arma::vec diff(D);
for (arma::uword i = 0; i < N; i++) {
if (!icpt_only_omega || i == 0) {
update_omega(i);
}
const arma::vec& counts = E_omega.col(i);
idx = arma::find(counts);
if (idx.n_elem > 0) {
double val = arma::dot(counts.rows(idx), log_omega.rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;
diff.zeros();
diff.rows(idx) = counts(idx) - omega.rows(idx);
grad -= arma::vectorise(tQd * diff * X_omega.col(i).t());
double val = arma::dot(counts.rows(idx), log_omega.rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;
grad -= arma::vectorise(tQd * (counts - omega) * X_omega.col(i).t());
}
grad += lambda * x;

Expand Down Expand Up @@ -104,24 +99,19 @@ double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
grad.zeros();
arma::mat tQs = Qs.t();
arma::uvec idx(S);
arma::vec diff(S);
for (arma::uword i = 0; i < N; i++) {
if (!icpt_only_pi || i == 0) {
update_pi(i, current_d);
}
const arma::vec& counts = E_Pi(current_d).col(i);
const arma::vec& counts = E_pi(current_d).col(i);
idx = arma::find(counts);
if (idx.n_elem > 0) {
double val = arma::dot(counts.rows(idx), log_pi(current_d).rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;
diff.zeros();
diff.rows(idx) = counts.rows(idx) - pi(current_d).rows(idx);
grad -= arma::vectorise(tQs * diff * X_pi.col(i).t());
double val = arma::dot(counts.rows(idx), log_pi(current_d).rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;
grad -= arma::vectorise(tQs * (counts - pi(current_d)) * X_pi.col(i).t());
}
grad += lambda * x;

Expand All @@ -136,7 +126,7 @@ void mnhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs,
// Use closed form solution
if (icpt_only_pi && lambda < 1e-12) {
for (arma::uword d = 0; d < D; d++) {
eta_pi(d) = Qs.t() * log(arma::sum(E_Pi(d), 1) + arma::datum::eps);
eta_pi(d) = Qs.t() * log(arma::sum(E_pi(d), 1) + arma::datum::eps);
if (!eta_pi(d).is_finite()) {
mstep_return_code = -100;
return;
Expand Down Expand Up @@ -199,7 +189,6 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
arma::vec log_A1(S);
grad.zeros();
arma::uvec idx(S);
arma::vec diff(S);
if (!iv_A && !tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(0).col(0));
log_A1 = log(A1);
Expand All @@ -213,23 +202,18 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
for (arma::uword t = 0; t < (Ti(i) - 1); t++) {
const arma::vec& counts = E_A(current_s, current_d).slice(t).col(i);
idx = arma::find(counts);
if (idx.n_elem > 0) {
double sum_ea = arma::accu(counts.rows(idx));
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
}
double val = arma::dot(counts.rows(idx), log_A1.rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;

diff.zeros();
diff.rows(idx) = counts.rows(idx) - sum_ea * A1.rows(idx);
grad -= arma::vectorise(tQs * diff * X_A.slice(i).col(t).t());
double sum_ea = arma::accu(counts.rows(idx));
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
}
double val = arma::dot(counts.rows(idx), log_A1.rows(idx));
if (!std::isfinite(val)) {
grad.zeros();
return maxval;
}
value -= val;
grad -= arma::vectorise(tQs * (counts - sum_ea * A1) * X_A.slice(i).col(t).t());
}
}
grad += lambda * x;
Expand Down
16 changes: 11 additions & 5 deletions src/mnhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct mnhmm_base {
arma::cube log_py;
// expected counts for EM algorithm
arma::mat E_omega;
arma::field<arma::mat> E_Pi;
arma::field<arma::mat> E_pi;
arma::field<arma::cube> E_A;
arma::uword current_s;
arma::uword current_d;
Expand Down Expand Up @@ -77,7 +77,7 @@ struct mnhmm_base {
const arma::field<arma::cube>& eta_A_,
const arma::uword n_obs_ = 0,
const double lambda_ = 0,
double maxval_ = 1e6)
double maxval_ = arma::datum::inf)
: S(S_),
D(D_),
X_omega(X_d_),
Expand Down Expand Up @@ -115,7 +115,7 @@ struct mnhmm_base {
log_A(D),
log_py(S, T, D),
E_omega(D, N),
E_Pi(D),
E_pi(D),
E_A(S, D),
current_s(0),
current_d(0),
Expand All @@ -127,7 +127,7 @@ struct mnhmm_base {
log_pi(d) = arma::vec(S);
A(d) = arma::cube(S, S, T);
log_A(d) = arma::cube(S, S, T);
E_Pi(d) = arma::mat(S, N);
E_pi(d) = arma::mat(S, N);
for (arma::uword s = 0; s < S; s++) {
E_A(s, d) = arma::cube(S, N, T);
}
Expand Down Expand Up @@ -234,12 +234,16 @@ struct mnhmm_base {
// E-step update for the omega-related expected counts of a single index i.
//
// Writes exp(ll_i - ll), computed element-wise, into column i of E_omega and
// then flushes negligibly small entries of that column to exact zero.
//
// Parameters:
//   i    - column of E_omega to overwrite.
//   ll_i - vector of log-scale terms; presumably one per-cluster
//          log-likelihood for sequence i -- TODO confirm against caller.
//   ll   - scalar subtracted from every element of ll_i before
//          exponentiation; presumably the total log-likelihood of sequence i
//          (log-sum-exp of ll_i), which would make the column sum to one --
//          TODO confirm against caller.
void estep_omega(const arma::uword i, const arma::vec ll_i,
const double ll) {
E_omega.col(i) = arma::exp(ll_i - ll);
// set minuscule values to zero in order to avoid numerical issues
// (the threshold is the smallest positive normalized double, so only
// subnormal-magnitude entries are affected)
E_omega.col(i).clean(std::numeric_limits<double>::min());
}

void estep_pi(const arma::uword i, const arma::uword d,
const arma::vec& log_alpha,
const arma::vec& log_beta, const double ll) {
E_Pi(d).col(i) = arma::exp(log_alpha + log_beta - ll);
E_pi(d).col(i) = arma::exp(log_alpha + log_beta - ll);
// set minuscule values to zero in order to avoid numerical issues
E_pi(d).col(i).clean(std::numeric_limits<double>::min());
}

void estep_A(const arma::uword i, const arma::uword d,
Expand All @@ -252,6 +256,8 @@ struct mnhmm_base {
log_beta(j, t + 1) + log_py(j, t + 1, d) - ll);
}
}
// set minuscule values to zero in order to avoid numerical issues
E_A(k, d).col(i).clean(std::numeric_limits<double>::min());
}
}
void mstep_omega(const double ftol_abs, const double ftol_rel,
Expand Down
2 changes: 1 addition & 1 deletion src/mnhmm_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct mnhmm_mc : public mnhmm_base {
for (arma::uword t = 0; t < Ti(i); t++) { // time
double pp = exp(log_alpha(k, t) + log_beta(k, t) - ll);
for (arma::uword c = 0; c < C; c++) { // channel
if (obs(c, t, i) < M(c)) {
if (obs(c, t, i) < M(c) && pp > std::numeric_limits<double>::min()) {
E_B(c, d)(t, i, k) = pp;
} else {
E_B(c, d)(t, i, k) = 0.0;
Expand Down
1 change: 1 addition & 0 deletions src/mnhmm_sc.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ struct mnhmm_sc : public mnhmm_base {
}
}
}
E_B(d).clean(std::numeric_limits<double>::min());
}
void mstep_B(const double ftol_abs, const double ftol_rel,
const double xtol_abs, const double xtol_rel,
Expand Down
Loading

0 comments on commit 931b631

Please sign in to comment.