Skip to content

Reuse compiled model methods across models #1083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ CmdStanFit$set("public", name = "init", value = init)
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (logical) Whether to show verbose logging during compilation.
#' @param hessian (logical) Whether to expose the (experimental) hessian method.
#' @param force_recompile (logical) Whether to recompile model methods, even if cached
#'
#' @examples
#' \dontrun{
Expand All @@ -335,14 +336,16 @@ CmdStanFit$set("public", name = "init", value = init)
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE) {
init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE, force_recompile = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
call. = FALSE)
}
require_suggested_package("Rcpp")
if (length(private$model_methods_env_$hpp_code_) == 0) {
if (length(private$model_methods_env_$hpp_code_) == 0 && (
is.null(private$model_methods_env_$obj_file_) ||
!file.exists(private$model_methods_env_$obj_file_))) {
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
"the model must be compiled again", call. = FALSE)
}
Expand All @@ -352,7 +355,7 @@ init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE) {
"errors that you encounter")
}
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
expose_model_methods(private$model_methods_env_, verbose, hessian, force_recompile = FALSE)
}
if (!("model_ptr_" %in% ls(private$model_methods_env_))) {
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
Expand Down
2 changes: 2 additions & 0 deletions R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ build_cmdstan <- function(dir,
clean_cmdstan <- function(dir = cmdstan_path(),
cores = getOption("mc.cores", 2),
quiet = FALSE) {
unlink(file.path(dir, "model_methods.o"))
unlink(file.path(dir, "model_methods.cpp"))
withr::with_envvar(
c("HOME" = short_path(Sys.getenv("HOME"))),
withr::with_path(
Expand Down
3 changes: 2 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ compile <- function(quiet = TRUE,
run_log <- wsl_compatible_run(
command = make_cmd(),
args = c(wsl_safe_path(repair_path(tmp_exe)),
cpp_options_to_compile_flags(cpp_options),
cpp_options_to_compile_flags(c(cpp_options, list("KEEP_OBJECT"="true", "CXXFLAGS += -fPIC"))),
stancflags_val),
wd = cmdstan_path(),
echo = !quiet || is_verbose_mode(),
Expand Down Expand Up @@ -735,6 +735,7 @@ compile <- function(quiet = TRUE,
file.remove(exe)
}
file.copy(tmp_exe, exe, overwrite = TRUE)
private$model_methods_env_$obj_file_ <- paste0(temp_file_no_ext, ".o")
if (os_is_wsl()) {
res <- processx::run(
command = "wsl",
Expand Down
112 changes: 88 additions & 24 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -785,44 +785,108 @@ check_sundials_fpic <- function(verbose) {
}
}

rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
with_cmdstan_flags <- function(expr) {
check_sundials_fpic(verbose)
cxxflags <- get_cmdstan_flags("CXXFLAGS")
cppflags <- get_cmdstan_flags("CPPFLAGS")
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
cmdstanr_includes <- paste0("-I", shQuote(cmdstanr_includes))

r_includes <- paste(
paste0("-I", shQuote(system.file("include", package = "Rcpp", mustWork = TRUE))),
paste0("-I", shQuote(R.home(component = "include")))
)

libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB", "SUNDIALS_TARGETS")
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = " ")
if (.Platform$OS.type == "windows") {
libs <- paste(libs, "-fopenmp")
}
if (cmdstan_version() <= "2.30.1") {
cppflags <- paste0(cppflags, " -DCMDSTAN_JSON")
if (os_is_windows()) {
libs <- paste(libs, "-fopenmp -lstdc++")
}
withr::with_path(repair_path(file.path(cmdstan_path(),"stan/lib/stan_math/lib/tbb")),
withr::with_makevars(
c(
USE_CXX14 = 1,
PKG_CPPFLAGS = cppflags,
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
PKG_LIBS = libs
),
Rcpp::sourceCpp(code = code, env = env, verbose = verbose, ...)
)
new_makevars <- c(
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste(cxxflags, cmdstanr_includes, r_includes, collapse = " "),
PKG_LIBS = libs
)
withr::with_path(
c(
repair_path(file.path(cmdstan_path(),"stan/lib/stan_math/lib/tbb")),
toolchain_PATH_env_var()
),
withr::with_makevars(new_makevars, expr)
)
}

rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
with_cmdstan_flags(Rcpp::sourceCpp(code = code, env = env, verbose = verbose, ...))
invisible(NULL)
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
initialize_method_functions <- function(env, so_name) {
env$model_ptr <-
function(...) { .Call("model_ptr_", ..., PACKAGE = so_name) }
env$log_prob <-
function(...) { .Call("log_prob_", ..., PACKAGE = so_name) }
env$grad_log_prob <-
function(...) { .Call("grad_log_prob_", ..., PACKAGE = so_name) }
env$hessian <-
function(...) { .Call("hessian_", ..., PACKAGE = so_name) }
env$get_num_upars <-
function(...) { .Call("get_num_upars_", ..., PACKAGE = so_name) }
env$get_param_metadata <-
function(...) { .Call("get_param_metadata_", ..., PACKAGE = so_name) }
env$unconstrain_variables <-
function(...) { .Call("unconstrain_variables_", ..., PACKAGE = so_name) }
env$unconstrain_draws <-
function(...) { .Call("unconstrain_draws_", ..., PACKAGE = so_name) }
env$constrain_variables <-
function(...) { .Call("constrain_variables_", ..., PACKAGE = so_name) }
env$unconstrained_param_names <-
function(...) { .Call("unconstrained_param_names_", ..., PACKAGE = so_name) }
env$constrained_param_names <-
function(...) { .Call("constrained_param_names_", ..., PACKAGE = so_name) }
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE, force_recompile = FALSE) {
precomp_methods_file <- file.path(cmdstan_path(), "model_methods.o")
if (file.exists(precomp_methods_file) && force_recompile) {
unlink(precomp_methods_file)
}
model_methods_cpp <- system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)
source_file <- paste0(strip_ext(precomp_methods_file), ".cpp")
file.copy(model_methods_cpp, source_file, overwrite = FALSE)

model_obj_file <- env$obj_file_
if (!file.exists(model_obj_file)) {
if (rlang::is_interactive()) {
message("Model object file not found, recompiling model...")
}
temp_hpp_file <- tempfile()
writeLines(env$hpp_code_, con = paste0(temp_hpp_file, ".cpp"))
model_obj_file <- paste0(temp_hpp_file, ".o")
}

if (!file.exists(precomp_methods_file) && rlang::is_interactive()) {
message("Compiling and caching additional model methods...")
}
if (rlang::is_interactive()) {
message("Compiling additional model methods...")
message("Linking precompiled model methods to model object file...")
}
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
methods_dll <- tempfile(fileext = .Platform$dynlib.ext)
with_cmdstan_flags(
processx::run(
command = file.path(R.home(component = "bin"), "R"),
args = c("CMD", "SHLIB", repair_path(model_obj_file), repair_path(precomp_methods_file),
"-o", repair_path(methods_dll)),
echo = verbose || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
error_on_status = FALSE
)
)

env$methods_dll_info <- with_cmdstan_flags(dyn.load(methods_dll, local = TRUE, now = TRUE))
initialize_method_functions(env, strip_ext(basename(methods_dll)))
invisible(NULL)
}

Expand Down
106 changes: 67 additions & 39 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ stan::model::model_base&
new_model(stan::io::var_context& data_context, unsigned int seed,
std::ostream* msg_stream);

// [[Rcpp::export]]
Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
RcppExport SEXP model_ptr_(SEXP data_path_, SEXP seed_) {
BEGIN_RCPP
std::string data_path = Rcpp::as<std::string>(data_path_);
boost::uint32_t seed = Rcpp::as<boost::uint32_t>(seed_);
Rcpp::XPtr<stan::model::model_base> ptr(
&new_model(*var_context(data_path), seed, &Rcpp::Rcout)
);
Expand All @@ -41,41 +43,48 @@ Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
Rcpp::Named("model_ptr") = ptr,
Rcpp::Named("base_rng") = base_rng
);
END_RCPP
}

// [[Rcpp::export]]
double log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jac_adjust) {
RcppExport SEXP log_prob_(SEXP ext_model_ptr, SEXP upars_, SEXP jacobian_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
if (jac_adjust) {
return stan::model::log_prob_propto<true>(*ptr.get(), upars, &Rcpp::Rcout);
double rtn;
Eigen::VectorXd upars = Rcpp::as<Eigen::VectorXd>(upars_);
if (Rcpp::as<bool>(jacobian_)) {
rtn = stan::model::log_prob_propto<true>(*ptr.get(), upars, &Rcpp::Rcout);
} else {
return stan::model::log_prob_propto<false>(*ptr.get(), upars, &Rcpp::Rcout);
rtn = stan::model::log_prob_propto<false>(*ptr.get(), upars, &Rcpp::Rcout);
}
return Rcpp::wrap(rtn);
END_RCPP
}

// [[Rcpp::export]]
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars,
bool jac_adjust) {
RcppExport SEXP grad_log_prob_(SEXP ext_model_ptr, SEXP upars_, SEXP jacobian_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::VectorXd gradients;
Eigen::VectorXd upars = Rcpp::as<Eigen::VectorXd>(upars_);

double lp;
if (jac_adjust) {
if (Rcpp::as<bool>(jacobian_)) {
lp = stan::model::log_prob_grad<true, true>(*ptr.get(), upars, gradients);
} else {
lp = stan::model::log_prob_grad<true, false>(*ptr.get(), upars, gradients);
}
Rcpp::NumericVector grad_rtn(Rcpp::wrap(std::move(gradients)));
grad_rtn.attr("log_prob") = lp;
return grad_rtn;
END_RCPP
}

// [[Rcpp::export]]
Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
RcppExport SEXP hessian_(SEXP ext_model_ptr, SEXP upars_, SEXP jacobian_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::VectorXd upars = Rcpp::as<Eigen::VectorXd>(upars_);

auto hessian_functor = [&](auto&& x) {
if (jacobian) {
if (Rcpp::as<bool>(jacobian_)) {
return ptr->log_prob<true, true>(x, 0);
} else {
return ptr->log_prob<true, false>(x, 0);
Expand All @@ -92,16 +101,18 @@ Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
Rcpp::Named("log_prob") = log_prob,
Rcpp::Named("grad_log_prob") = grad,
Rcpp::Named("hessian") = hessian);
END_RCPP
}

// [[Rcpp::export]]
size_t get_num_upars(SEXP ext_model_ptr) {
RcppExport SEXP get_num_upars_(SEXP ext_model_ptr) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
return ptr->num_params_r();
return Rcpp::wrap(ptr->num_params_r());
END_RCPP
}

// [[Rcpp::export]]
Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
RcppExport SEXP get_param_metadata_(SEXP ext_model_ptr) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<std::string> param_names;
std::vector<std::vector<size_t> > param_dims;
Expand All @@ -116,27 +127,31 @@ Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
}

return param_metadata;
END_RCPP
}

// [[Rcpp::export]]
Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variables) {
RcppExport SEXP unconstrain_variables_(SEXP ext_model_ptr, SEXP variables_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::VectorXd variables = Rcpp::as<Eigen::VectorXd>(variables_);
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables, unconstrained_variables, &Rcpp::Rcout);
return unconstrained_variables;
return Rcpp::wrap(unconstrained_variables);
END_RCPP
}

// [[Rcpp::export]]
Rcpp::List unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
RcppExport SEXP unconstrain_draws_(SEXP ext_model_ptr, SEXP variables_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::MatrixXd variables = Rcpp::as<Eigen::MatrixXd>(variables_);
// Need to do this for the first row to get the correct size of the unconstrained draws
Eigen::VectorXd unconstrained_draw1;
ptr->unconstrain_array(variables.row(0).transpose(), unconstrained_draw1, &Rcpp::Rcout);
std::vector<Eigen::VectorXd> unconstrained_draws(unconstrained_draw1.size());
for (auto&& unconstrained_par : unconstrained_draws) {
unconstrained_par.resize(variables.rows());
}

for (int i = 0; i < variables.rows(); i++) {
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
Expand All @@ -145,36 +160,49 @@ Rcpp::List unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
}
}
return Rcpp::wrap(unconstrained_draws);
END_RCPP
}

// [[Rcpp::export]]
std::vector<double> constrain_variables(SEXP ext_model_ptr, SEXP base_rng,
std::vector<double> upars,
bool return_trans_pars,
bool return_gen_quants) {
RcppExport SEXP constrain_variables_(SEXP ext_model_ptr, SEXP base_rng,
SEXP upars_,
SEXP return_trans_pars_,
SEXP return_gen_quants_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Rcpp::XPtr<stan::rng_t> rng(base_rng);
std::vector<double> upars = Rcpp::as<std::vector<double>>(upars_);
bool return_trans_pars = Rcpp::as<bool>(return_trans_pars_);
bool return_gen_quants = Rcpp::as<bool>(return_gen_quants_);
std::vector<int> params_i;
std::vector<double> vars;

ptr->write_array(*rng.get(), upars, params_i, vars, return_trans_pars, return_gen_quants);
return vars;
return Rcpp::wrap(vars);
END_RCPP
}

// [[Rcpp::export]]
std::vector<std::string> unconstrained_param_names(SEXP ext_model_ptr, bool return_trans_pars, bool return_gen_quants) {
RcppExport SEXP unconstrained_param_names_(SEXP ext_model_ptr,
SEXP return_trans_pars_,
SEXP return_gen_quants_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
bool return_trans_pars = Rcpp::as<bool>(return_trans_pars_);
bool return_gen_quants = Rcpp::as<bool>(return_gen_quants_);
std::vector<std::string> rtn_names;
ptr->unconstrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
return rtn_names;
return Rcpp::wrap(rtn_names);
END_RCPP
}

// [[Rcpp::export]]
std::vector<std::string> constrained_param_names(SEXP ext_model_ptr,
bool return_trans_pars,
bool return_gen_quants) {
RcppExport SEXP constrained_param_names_(SEXP ext_model_ptr,
SEXP return_trans_pars_,
SEXP return_gen_quants_) {
BEGIN_RCPP
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
bool return_trans_pars = Rcpp::as<bool>(return_trans_pars_);
bool return_gen_quants = Rcpp::as<bool>(return_gen_quants_);
std::vector<std::string> rtn_names;
ptr->constrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
return rtn_names;
return Rcpp::wrap(rtn_names);
END_RCPP
}
Loading
Loading