diff --git a/R/sampler-nested-adaptive.R b/R/sampler-nested-adaptive.R index 7f919921..9f19417a 100644 --- a/R/sampler-nested-adaptive.R +++ b/R/sampler-nested-adaptive.R @@ -159,7 +159,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv, list(base = autocorrelation_base, groups = autocorrelation_groups) internal$vcv <- list(base = vcv_base, groups = vcv_groups) proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups) - internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups) + internal$proposal <- nested_proposal_adaptive(proposal_vcv, + model$parameter_groups) internal$history_pars <- numeric() internal$included <- integer() @@ -311,7 +312,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv, ## Update proposal proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups) - internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups) + internal$proposal <- nested_proposal_adaptive(proposal_vcv, + model$parameter_groups) state } @@ -425,3 +427,66 @@ check_nested_adaptive <- function(x, n_groups, has_base, null_allowed = FALSE, ret } + + +## TODO: this is a simpler version of nested_proposal that does not +## cope with boundaries etc - that's being looked at in #46 for now. +## nocov start +nested_proposal_adaptive <- function(vcv, parameter_groups, call = NULL) { + i_base <- parameter_groups == 0 + n_base <- sum(i_base) + n_groups <- max(parameter_groups) + i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i)) + if (NROW(vcv$base) != n_base) { + cli::cli_abort( + c("Incompatible number of base parameters in your model and sampler", + i = paste("Your model has {n_base} base parameters, but 'vcv$base'", + "implies {NROW(vcv$base)} parameters")), + call = call) + } + if (length(vcv$groups) != n_groups) { + cli::cli_abort( + c("Incompatible number of parameter groups in your model and sampler", + i = paste("Your model has {n_groups} parameter groups, but", + "'vcv$groups' has {length(vcv$groups)} groups")), + call = call) + } + n_pars_by_group <- lengths(i_group) + n_pars_by_group_vcv <- vnapply(vcv$groups, nrow) + err <- n_pars_by_group_vcv != n_pars_by_group + if (any(err)) { + detail <- sprintf( + "Group %d has %d parameters but 'vcv$groups[[%d]]' has %d", + which(err), n_pars_by_group[err], + which(err), n_pars_by_group_vcv[err]) + cli::cli_abort( + c("Incompatible number of parameters within parameter group", + set_names(detail, "i")), + call = call) + } + + has_base <- n_base > 0 + if (has_base) { + mvn_base <- make_rmvnorm(vcv$base) + proposal_base <- function(x, rng) { + ## This approach is likely to be a bit fragile, so we'll + ## probably want some naming related verification here soon too. + x[i_base] <- mvn_base(x[i_base], rng) + x + } + } else { + proposal_base <- NULL + } + + mvn_groups <- lapply(vcv$groups, make_rmvnorm) + proposal_groups <- function(x, rng) { + for (i in seq_len(n_groups)) { + x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) + } + x + } + + list(base = proposal_base, + groups = proposal_groups) +} +## nocov end