Commit

stylr + lintr
LHBO committed Feb 3, 2024
1 parent 6598ba0 commit 25ea590
Showing 16 changed files with 344 additions and 322 deletions.
120 changes: 60 additions & 60 deletions R/approach_vaeac.R

Large diffs are not rendered by default.

264 changes: 143 additions & 121 deletions R/approach_vaeac_extra_functions.R

Large diffs are not rendered by default.

90 changes: 45 additions & 45 deletions R/approach_vaeac_torch_modules.R
@@ -21,27 +21,27 @@
#' @param activation_function A [torch::nn_module()] representing an activation function such as, e.g.,
#' [torch::nn_relu()], [torch::nn_leaky_relu()], [torch::nn_selu()],
#' [torch::nn_sigmoid()].
-#' @param skip_connection_layer Boolean. If we are to use skip connections in each layer, see [shapr::SkipConnection()].
+#' @param skip_conn_layer Boolean. If we are to use skip connections in each layer, see [shapr::SkipConnection()].
#' If `TRUE`, then we add the input to the output of each hidden layer, so the output becomes
#' \eqn{X + \operatorname{activation}(WX + b)}. I.e., the identity skip connection.
-#' @param skip_connection_masked_enc_dec Boolean. If we are to apply concatenating skip
+#' @param skip_conn_masked_enc_dec Boolean. If we are to apply concatenating skip
#' connections between the layers in the masked encoder and decoder. The first layer of the masked encoder will be
#' linked to the last layer of the decoder. The second layer of the masked encoder will be
#' linked to the second to last layer of the decoder, and so on.
#' @param batch_normalization Boolean. If we are to use batch normalization after the activation function.
-#' Note that if `skip_connection_layer` is TRUE, then the normalization is
+#' Note that if `skip_conn_layer` is TRUE, then the normalization is
#' done after the addition from the skip connection. I.e., we batch normalize the whole quantity X + activation(WX + b).
#' @param paired_sampling Boolean. If we are doing paired sampling. I.e., if we are to include both coalition S
#' and \eqn{\bar{S}} when we sample coalitions during training for each batch.
#' @param mask_generator_name String specifying the type of mask generator to use. Needs to be one of
#' 'MCAR_mask_generator', 'Specified_prob_mask_generator', and 'Specified_masks_mask_generator'.
#' @param masking_ratio Scalar. The probability for an entry in the generated mask to be 1 (masked).
-#' Not used if `mask_gen_these_coalitions` is given.
-#' @param mask_gen_these_coalitions Matrix containing the different coalitions to learn.
+#' Not used if `mask_gen_coalitions` is given.
+#' @param mask_gen_coalitions Matrix containing the different coalitions to learn.
#' Must be given if `mask_generator_name = 'Specified_masks_mask_generator'`.
-#' @param mask_gen_these_coalitions_prob Numerics containing the probabilities
-#' for sampling each mask in `mask_gen_these_coalitions`.
-#' Array containing the probabilities for sampling the coalitions in `mask_gen_these_coalitions`.
+#' @param mask_gen_coalitions_prob Numerics containing the probabilities
+#' for sampling each mask in `mask_gen_coalitions`.
+#' Array containing the probabilities for sampling the coalitions in `mask_gen_coalitions`.
#' @param sigma_mu Numeric representing a hyperparameter in the normal-gamma prior used on the masked encoder,
#' see Section 3.3.1 in \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#' @param sigma_sigma Numeric representing a hyperparameter in the normal-gamma prior used on the masked encoder,
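[Editor's note: a minimal sketch of the identity skip connection that the `skip_conn_layer` docs above describe, output = X + activation(WX + b). It uses the torch R package directly; the module name `identity_skip` and its layout are illustrative assumptions, not shapr's actual `SkipConnection()` implementation.]

```r
library(torch)

# Illustrative sketch only, not shapr's SkipConnection(): the block returns
# x + activation(W x + b), i.e., the input is added back onto the output of
# the linear layer and activation function, so the layer learns a residual.
identity_skip <- torch::nn_module(
  "identity_skip",
  initialize = function(width, activation_function = torch::nn_relu) {
    self$linear <- torch::nn_linear(width, width)
    self$activation <- activation_function()
  },
  forward = function(x) {
    x + self$activation(self$linear(x)) # X + activation(WX + b)
  }
)

block <- identity_skip(width = 8)
x <- torch::torch_randn(4, 8)
block(x)$shape # 4 8: same shape as the input, as the addition requires
```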
@@ -127,8 +127,8 @@ vaeac <- torch::nn_module(
depth = 3,
latent_dim = 8,
activation_function = torch::nn_relu,
-skip_connection_layer = FALSE,
-skip_connection_masked_enc_dec = FALSE,
+skip_conn_layer = FALSE,
+skip_conn_masked_enc_dec = FALSE,
batch_normalization = FALSE,
paired_sampling = FALSE,
mask_generator_name = c(
@@ -137,8 +137,8 @@
"Specified_masks_mask_generator"
),
masking_ratio = 0.5,
-mask_gen_these_coalitions = NULL,
-mask_gen_these_coalitions_prob = NULL,
+mask_gen_coalitions = NULL,
+mask_gen_coalitions_prob = NULL,
sigma_mu = 1e4,
sigma_sigma = 1e-4) {
# Check that a valid mask_generator was provided.
@@ -149,7 +149,7 @@

# Extra strings to add to names of layers depending on if we use memory layers and/or batch normalization.
# If FALSE, they are just an empty string and do not affect the names.
-name_extra_memory_layer <- ifelse(skip_connection_masked_enc_dec, "_and_memory", "")
+name_extra_memory_layer <- ifelse(skip_conn_masked_enc_dec, "_and_memory", "")
name_extra_memory_layer <- ifelse(skip_conn_masked_enc_dec, "_and_memory", "")
name_extra_batch_normalize <- ifelse(batch_normalization, "_and_batch_norm", "")

# Save some of the initializing hyperparameters to the vaeac object. Others are saved later.
@@ -158,8 +158,8 @@
self$width <- width
self$latent_dim <- latent_dim
self$activation_function <- activation_function
-self$skip_connection_layer <- skip_connection_layer
-self$skip_connection_masked_enc_dec <- skip_connection_masked_enc_dec
+self$skip_conn_layer <- skip_conn_layer
+self$skip_conn_masked_enc_dec <- skip_conn_masked_enc_dec
self$batch_normalization <- batch_normalization
self$sigma_mu <- sigma_mu
self$sigma_sigma <- sigma_sigma
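[Editor's note: a small illustration of what `masking_ratio` means for the 'MCAR_mask_generator' branch handled in the next hunk: each mask entry is drawn independently as 1 (masked/unobserved) with probability `masking_ratio`. The helper `sample_mcar_mask()` is a hypothetical name for illustration, not shapr's MCAR_mask_generator; the last line hints at the `paired_sampling` option, which also trains on each mask's complement.]

```r
# Hypothetical helper, not shapr's MCAR_mask_generator: draw a Bernoulli
# mask where 1 means the feature is masked (to be imputed) and 0 means
# it is observed.
sample_mcar_mask <- function(n_obs, n_features, masking_ratio = 0.5) {
  matrix(
    rbinom(n_obs * n_features, size = 1, prob = masking_ratio),
    nrow = n_obs, ncol = n_features
  )
}

set.seed(123)
mask <- sample_mcar_mask(n_obs = 2, n_features = 4)
mask     # each entry is 1 with probability masking_ratio
1 - mask # with paired sampling, the complement masks are included too
```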
@@ -195,23 +195,23 @@
self$masking_probs <- masking_ratio
} else if (mask_generator_name == "Specified_masks_mask_generator") {
# Small check that they have been provided.
-if (is.null(mask_gen_these_coalitions) | is.null(mask_gen_these_coalitions_prob)) {
+if (is.null(mask_gen_coalitions) | is.null(mask_gen_coalitions_prob)) {
stop(paste0(
"Both 'mask_gen_these_coalitions' and 'mask_gen_these_coalitions_prob' ",
"Both 'mask_gen_coalitions' and 'mask_gen_coalitions_prob' ",
"must be provided when using 'Specified_masks_mask_generator'."
))
}

# Create a Specified_masks_mask_generator and attach it to the vaeac object.
self$mask_generator <- Specified_masks_mask_generator(
-masks = mask_gen_these_coalitions,
-masks_probs = mask_gen_these_coalitions_prob,
+masks = mask_gen_coalitions,
+masks_probs = mask_gen_coalitions_prob,
paired_sampling = paired_sampling
)

# Save the possible masks and corresponding probabilities to the vaeac object.
-self$masks <- mask_gen_these_coalitions
-self$masks_probs <- mask_gen_these_coalitions_prob
+self$masks <- mask_gen_coalitions
+self$masks_probs <- mask_gen_coalitions_prob
} else {
# Print error to user.
stop(paste0(
@@ -248,7 +248,7 @@ vaeac <- torch::nn_module(

# Full Encoder: Hidden layers
for (i in seq(depth)) {
-if (skip_connection_layer) {
+if (skip_conn_layer) {
# Add identity skip connection. Such that the input is added to the output of the linear layer
# and activation function: output = X + activation(WX + b).
full_encoder_network$add_module(
@@ -257,7 +257,7 @@
activation_function(),
if (batch_normalization) torch::nn_batch_norm1d(width)
),
name = paste0("hidden_layer_", i, "_skip_connection_with_linear_and_activation", name_extra_batch_normalize)
name = paste0("hidden_layer_", i, "_skip_conn_with_linear_and_activation", name_extra_batch_normalize)
)
} else {
# Do not use skip connections and do not add the input to the output.
@@ -292,7 +292,7 @@ vaeac <- torch::nn_module(
module = CategoricalToOneHotLayer(c(one_hot_max_sizes, rep(0, n_features))),
name = "input_layer_cat_to_one_hot"
)
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
masked_encoder_network$add_module(
module = MemoryLayer("#input"),
name = "input_layer_memory"
@@ -318,18 +318,18 @@

# Masked Encoder: Hidden layers
for (i in seq(depth)) {
-if (skip_connection_layer) {
+if (skip_conn_layer) {
# Add identity skip connection. Such that the input is added to the output of the linear layer
# and activation function: output = X + activation(WX + b).
# Also check inside SkipConnection if we are to use MemoryLayer. I.e., skip connection with
# concatenation from masked encoder to decoder.
masked_encoder_network$add_module(
module = SkipConnection(
-if (skip_connection_masked_enc_dec) MemoryLayer(paste0("#", i)),
+if (skip_conn_masked_enc_dec) MemoryLayer(paste0("#", i)),
torch::nn_linear(width, width),
activation_function()
),
name = paste0("hidden_layer_", i, "_skip_connection_with_linear_and_activation", name_extra_memory_layer)
name = paste0("hidden_layer_", i, "_skip_conn_with_linear_and_activation", name_extra_memory_layer)
)
if (batch_normalization) {
masked_encoder_network$add_module(
@@ -339,7 +339,7 @@ vaeac <- torch::nn_module(
}
} else {
# Do not use skip connections and do not add the input to the output.
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
masked_encoder_network$add_module(
module = MemoryLayer(paste0("#", i)),
name = paste0("hidden_layer_", i, "_memory")
@@ -363,7 +363,7 @@ vaeac <- torch::nn_module(
}

# Masked Encoder: Go to latent space
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
masked_encoder_network$add_module(
module = MemoryLayer(paste0("#", depth + 1)),
name = "latent_space_layer_memory"
@@ -395,22 +395,22 @@

# Get the width of the hidden layers in the decoder. Needs to be multiplied with two if
# we use skip connections between masked encoder and decoder as we concatenate the tensors.
-width_decoder <- ifelse(skip_connection_masked_enc_dec, 2 * width, width)
+width_decoder <- ifelse(skip_conn_masked_enc_dec, 2 * width, width)

# Same for the input dimension to the last layer in decoder that yields the distribution params.
extra_params_skip_con_mask_enc <-
-ifelse(test = skip_connection_masked_enc_dec,
+ifelse(test = skip_conn_masked_enc_dec,
yes = sum(apply(rbind(one_hot_max_sizes, rep(1, n_features)), 2, max)) + n_features,
no = 0
)

# Will need an extra hidden layer if we use skip connection from masked encoder to decoder
# as we send the full input layer of the masked encoder to the last layer in the decoder.
-depth_decoder <- ifelse(skip_connection_masked_enc_dec, depth + 1, depth)
+depth_decoder <- ifelse(skip_conn_masked_enc_dec, depth + 1, depth)

# Decoder: Hidden layers
for (i in seq(depth_decoder)) {
-if (skip_connection_layer) {
+if (skip_conn_layer) {
# Add identity skip connection. Such that the input is added to the output of the linear layer
# and activation function: output = X + activation(WX + b).
# Also check inside SkipConnection if we are to use MemoryLayer. I.e., skip connection with
@@ -423,14 +423,14 @@ vaeac <- torch::nn_module(
decoder_network$add_module(
module = torch::nn_sequential(
SkipConnection(
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
MemoryLayer(paste0("#", depth - i + 2), TRUE)
},
torch::nn_linear(width_decoder, width),
activation_function()
)
),
name = paste0("hidden_layer_", i, "_skip_connection_with_linear_and_activation", name_extra_memory_layer)
name = paste0("hidden_layer_", i, "_skip_conn_with_linear_and_activation", name_extra_memory_layer)
)
if (batch_normalization) {
decoder_network$add_module(
@@ -440,7 +440,7 @@ vaeac <- torch::nn_module(
}
} else {
# Do not use skip connections and do not add the input to the output.
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
decoder_network$add_module(
module = MemoryLayer(paste0("#", depth - i + 2), TRUE),
name = paste0("hidden_layer_", i, "_memory")
@@ -465,7 +465,7 @@ vaeac <- torch::nn_module(

# Decoder: Go to the parameter space of the generative distributions
# Concatenate the input to the first layer of the masked encoder to the last layer of the decoder network.
-if (skip_connection_masked_enc_dec) {
+if (skip_conn_masked_enc_dec) {
decoder_network$add_module(
module = MemoryLayer("#input", TRUE),
name = "output_layer_memory"
Expand Down Expand Up @@ -1344,7 +1344,7 @@ SkipConnection <- torch::nn_module(
vaeac_extend_batch <- function(batch, dataloader, batch_size) {
# Check if the batch contains too few observations and in that case add the missing number of obs from a new batch
while (batch$shape[1] < batch_size) { # Use while in case a single extra batch is not enough to get to `batch_size`
-batch_extra = dataloader$.iter()$.next()
+batch_extra <- dataloader$.iter()$.next()
batch <- torch::torch_cat(c(batch, batch_extra[seq(min(nrow(batch_extra), batch_size - batch$shape[1])), ]), 1)
}

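[Editor's note: a toy illustration of the padding arithmetic in `vaeac_extend_batch()` above, using plain tensors instead of a dataloader; the sizes and values are made-up assumptions, chosen so a short final batch of 2 rows is topped up to `batch_size = 8`.]

```r
library(torch)

batch <- torch_randn(2, 4)       # short final batch (2 observations)
batch_extra <- torch_randn(8, 4) # stand-in for dataloader$.iter()$.next()
n_missing <- 8 - batch$shape[1]  # 6 observations still needed

# Take the first n_missing rows of the extra batch and row-bind (dim 1).
batch <- torch_cat(list(batch, batch_extra[seq_len(n_missing), ]), dim = 1)
batch$shape # 8 4
```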
@@ -1371,17 +1371,17 @@ vaeac_extend_batch <- function(batch, dataloader, batch_size) {
#' @param mask_generator A mask generator object that generates the masks.
#' @param batch_size Integer. The number of samples to include in each batch.
#' @param vaeac_model The vaeac model.
-#' @param validation_iwae_n_samples Number of samples to generate for computing the IWAE for each validation sample.
+#' @param val_iwae_n_samples Number of samples to generate for computing the IWAE for each validation sample.
#'
#' @return The average iwae over all instances in the validation dataset.
#'
#' @author Lars Henry Berge Olsen
#' @keywords internal
-vaeac_get_validation_iwae <- function(val_dataloader,
-                                      mask_generator,
-                                      batch_size,
-                                      vaeac_model,
-                                      validation_iwae_n_samples) {
+vaeac_get_val_iwae <- function(val_dataloader,
+                               mask_generator,
+                               batch_size,
+                               vaeac_model,
+                               val_iwae_n_samples) {
# Set variables to store the number of instances evaluated and avg_iwae
cum_size <- 0
avg_iwae <- 0
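[Editor's note: for orientation, the importance-weighted estimate that `batch_iwae()` computes can be written schematically as below, following the vaeac setup in Olsen et al. (2022), which the roxygen docs above cite. The notation is an assumption, not taken from this diff: \(q_\phi\) is the full encoder (proposal), \(p_\psi\) the masked encoder (prior), \(p_\theta\) the decoder, and \(K\) corresponds to `val_iwae_n_samples`.]

```latex
\mathrm{IWAE}(x, S)
  = \log \frac{1}{K} \sum_{k=1}^{K}
    \frac{p_\theta\left(x_{\bar{S}} \mid z_k\right)\,
          p_\psi\left(z_k \mid x_S\right)}
         {q_\phi\left(z_k \mid x\right)},
  \qquad z_k \sim q_\phi\left(z \mid x\right).
```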
@@ -1406,7 +1406,7 @@ vaeac_get_validation_iwae <- function(val_dataloader,
# Use torch::with_no_grad() since we are evaluating and do not need the gradients for backpropagation
torch::with_no_grad({
# Get the iwae for the first `init_size` observations in the batch. The other obs are just "padding".
-iwae <- vaeac_model$batch_iwae(batch, mask, validation_iwae_n_samples)[1:init_size, drop = FALSE]
+iwae <- vaeac_model$batch_iwae(batch, mask, val_iwae_n_samples)[1:init_size, drop = FALSE]

# Update the average iwae over all batches (over all instances). This is called recursive/online updating of
# the mean. It takes the old average * cum_size to get the old sum of iwae values and adds the sum of the newly computed iwae.
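[Editor's note: the recursive/online mean update described in the comment above can be sketched in plain R as follows; `online_mean_update()` is an illustrative helper, not part of shapr, and the actual code operates on torch tensors.]

```r
# new_mean = (old_mean * n_old + sum(new_values)) / (n_old + n_new)
online_mean_update <- function(old_mean, n_old, new_values) {
  (old_mean * n_old + sum(new_values)) / (n_old + length(new_values))
}

avg <- 0
n <- 0
for (chunk in list(c(1, 2), c(3, 4, 5))) { # batches of "iwae" values
  avg <- online_mean_update(avg, n, chunk)
  n <- n + length(chunk)
}
avg # 3, the mean of 1:5, computed one batch at a time
```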
24 changes: 12 additions & 12 deletions man/vaeac.Rd

Some generated files are not rendered by default.

12 changes: 6 additions & 6 deletions man/vaeac_check_mask_gen.Rd

Some generated files are not rendered by default.

