Skip to content

Commit

Permalink
Add integration helper until #46 is merged
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Aug 21, 2024
1 parent bbdbbdf commit 7f2d7fe
Showing 1 changed file with 67 additions and 2 deletions.
69 changes: 67 additions & 2 deletions R/sampler-nested-adaptive.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv,
list(base = autocorrelation_base, groups = autocorrelation_groups)
internal$vcv <- list(base = vcv_base, groups = vcv_groups)
proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups)
internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups)
internal$proposal <- nested_proposal_adaptive(proposal_vcv,
model$parameter_groups)

internal$history_pars <- numeric()
internal$included <- integer()
Expand Down Expand Up @@ -311,7 +312,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv,

## Update proposal
proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups)
internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups)
internal$proposal <- nested_proposal_adaptive(proposal_vcv,
model$parameter_groups)

state
}
Expand Down Expand Up @@ -425,3 +427,66 @@ check_nested_adaptive <- function(x, n_groups, has_base, null_allowed = FALSE,

ret
}


## TODO: this is a simpler version of nested_proposal that does not
## cope with boundaries etc - that's being looked at in #46 for now.
## nocov start
nested_proposal_adaptive <- function(vcv, parameter_groups, call = NULL) {
i_base <- parameter_groups == 0
n_base <- sum(i_base)
n_groups <- max(parameter_groups)
i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i))
if (NROW(vcv$base) != n_base) {
cli::cli_abort(
c("Incompatible number of base parameters in your model and sampler",
i = paste("Your model has {n_base} base parameters, but 'vcv$base'",
"implies {NROW(vcv$base)} parameters")),
call = call)
}
if (length(vcv$groups) != n_groups) {
cli::cli_abort(
c("Incompatible number of parameter groups in your model and sampler",
i = paste("Your model has {n_groups} parameter groups, but",
"'vcv$groups' has {length(vcv$groups)} groups")),
call = call)
}
n_pars_by_group <- lengths(i_group)
n_pars_by_group_vcv <- vnapply(vcv$groups, nrow)
err <- n_pars_by_group_vcv != n_pars_by_group
if (any(err)) {
detail <- sprintf(
"Group %d has %d parameters but 'vcv$groups[[%d]]' has %d",
which(err), n_pars_by_group[err],
which(err), n_pars_by_group_vcv[err])
cli::cli_abort(
c("Incompatible number of parameters within parameter group",
set_names(detail, "i")),
call = call)
}

has_base <- n_base > 0
if (has_base) {
mvn_base <- make_rmvnorm(vcv$base)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
}
} else {
proposal_base <- NULL
}

mvn_groups <- lapply(vcv$groups, make_rmvnorm)
proposal_groups <- function(x, rng) {
for (i in seq_len(n_groups)) {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
}
x
}

list(base = proposal_base,
groups = proposal_groups)
}
## nocov end

0 comments on commit 7f2d7fe

Please sign in to comment.