Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial checkpoint and restore functionality. #178

Merged
merged 7 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Suggests:
testthat (>= 2.1.0),
xml2,
bench
RoxygenNote: 7.2.1.9000
RoxygenNote: 7.2.3
VignetteBuilder: knitr
LinkingTo:
Rcpp,
Expand Down
21 changes: 19 additions & 2 deletions R/categorical_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ CategoricalVariable <- R6Class(
#' @description return a character vector of possible values.
#' Note that the order of the returned vector may not be the same order
#' that was given when the variable was intitialized, due to the underlying
#' unordered storage type.
#' unordered storage type.
get_categories = function() {
categorical_variable_get_categories(self$.variable)
},
Expand Down Expand Up @@ -94,6 +94,23 @@ CategoricalVariable <- R6Class(
size = function() variable_get_size(self$.variable),

.update = function() variable_update(self$.variable),
.resize = function() variable_resize(self$.variable)
.resize = function() variable_resize(self$.variable),

.checkpoint = function() {
categories <- self$get_categories()
values <- lapply(categories, function(c) self$get_index_of(c)$to_vector())
names(values) <- categories
values
},

.restore = function(values) {
stopifnot(names(values) == self$get_categories())
stopifnot(sum(sapply(values, length)) == categorical_variable_get_size(self$.variable))

for (c in names(values)) {
self$queue_update(c, values[[c]])
}
self$.update()
}
)
)
9 changes: 8 additions & 1 deletion R/double_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ DoubleVariable <- R6Class(
size = function() variable_get_size(self$.variable),

.update = function() variable_update(self$.variable),
.resize = function() variable_resize(self$.variable)
.resize = function() variable_resize(self$.variable),

.checkpoint = function() self$get_values(),
.restore = function(values) {
stopifnot(length(values) == variable_get_size(self$.variable))
self$queue_update(values)
self$.update()
}
)
)
9 changes: 8 additions & 1 deletion R/integer_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ IntegerVariable <- R6Class(
size = function() variable_get_size(self$.variable),

.update = function() variable_update(self$.variable),
.resize = function() variable_resize(self$.variable)
.resize = function() variable_resize(self$.variable),

.checkpoint = function() self$get_values(),
.restore = function(values) {
stopifnot(length(values) == variable_get_size(self$.variable))
self$queue_update(values)
self$.update()
}
)
)
9 changes: 8 additions & 1 deletion R/ragged_double.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ RaggedDouble <- R6Class(
size = function() variable_get_size(self$.variable),

.update = function() variable_update(self$.variable),
.resize = function() variable_resize(self$.variable)
.resize = function() variable_resize(self$.variable),

.checkpoint = function() self$get_values(),
.restore = function(values) {
stopifnot(length(values) == variable_get_size(self$.variable))
self$queue_update(values)
self$.update()
}
)
)
9 changes: 8 additions & 1 deletion R/ragged_integer.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ RaggedInteger <- R6Class(
size = function() variable_get_size(self$.variable),

.update = function() variable_update(self$.variable),
.resize = function() variable_resize(self$.variable)
.resize = function() variable_resize(self$.variable),

.checkpoint = function() self$get_values(),
.restore = function(values) {
stopifnot(length(values) == variable_get_size(self$.variable))
self$queue_update(values)
self$.update()
}
)
)
58 changes: 55 additions & 3 deletions R/simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#' @param variables a list of Variables
#' @param events a list of Events
#' @param processes a list of processes to execute on each timestep
#' @param timesteps the number of timesteps to simulate
#' @param timesteps the end timestep of the simulation. If `state` is not NULL, timesteps must be greater than `state$timestep`
#' @param state a checkpoint from which to resume the simulation
#' @examples
#' population <- 4
#' timesteps <- 5
Expand Down Expand Up @@ -35,12 +36,22 @@ simulation_loop <- function(
variables = list(),
events = list(),
processes = list(),
timesteps
timesteps,
state = NULL
) {
if (timesteps <= 0) {
stop('End timestep must be > 0')
}
for (t in seq_len(timesteps)) {

start <- 1
if (!is.null(state)) {
start <- restore_state(state, variables, events)
if (start > timesteps) {
stop("Restored state is already longer than timesteps")
}
}

for (t in seq(start, timesteps)) {
for (process in processes) {
execute_any_process(process, t)
}
Expand All @@ -60,6 +71,47 @@ simulation_loop <- function(
event$.tick()
}
}

invisible(checkpoint_state(timesteps, variables, events))
}

#' @title Save the simulation state
#' @description Save the simulation state in an R object, allowing it to be
#' resumed later using \code{\link[individual]{restore_state}}.
#' @param timesteps <- the number of time steps that have already been simulated
#' @param variables the list of Variables
#' @param events the list of Events
checkpoint_state <- function(timesteps, variables, events) {
random_state <- .GlobalEnv$.Random.seed
list(
variables=lapply(variables, function(v) v$.checkpoint()),
timesteps=timesteps,
random_state=random_state
)
}

#' @title Restore the simulation state
#' @description Restore the simulation state from a previous checkpoint.
#' The state of passed events and variables is overwritten to match the state they
#' had when the simulation was checkpointed. Returns the time step at which the
#' simulation should resume.
#' @param state the simulation state to restore, as returned by \code{\link[individual]{restore_state}}.
#' @param variables the list of Variables
#' @param events the list of Events
restore_state <- function(state, variables, events) {
if (length(variables) != length(state$variables)) {
stop("Checkpoint's variables do not match simulation's")
}
for (i in seq_along(variables)) {
variables[[i]]$.restore(state$variables[[i]])
}
if (length(events) > 0) {
stop("Events cannot be restored yet")
}

.GlobalEnv$.Random.seed <- state$random_state

state$timesteps + 1
}

#' @title Execute a C++ or R process in the simulation
Expand Down
4 changes: 2 additions & 2 deletions man/Bitset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/CategoricalVariable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/DoubleVariable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/IntegerVariable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/RaggedDouble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/RaggedInteger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions man/checkpoint_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading