Skip to content

Commit

Permalink
Fix: More stable sedynemo and fixed reinitialize function.
Browse files Browse the repository at this point in the history
  • Loading branch information
RukuangHuang committed Apr 17, 2023
1 parent 887e115 commit 4dd93bd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion osl_dynamics/inference/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def reinitialize_layer_weights(layer):
var = getattr(init_container, key.replace("_initializer", ""))

# Assign new random values to the variable
var.assign(new_initializer(var.shape, var.dtype))
if var is not None:
var.assign(new_initializer(var.shape, var.dtype))


def reinitialize_model_weights(model, keep=None):
Expand Down
8 changes: 4 additions & 4 deletions osl_dynamics/models/sedynemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _model_structure(config):
means_dev_map_layer = layers.Dense(config.n_channels, name="means_dev_map")

norm_means_dev_map_layer = layers.LayerNormalization(
axis=-1, name="norm_means_dev_map"
axis=-1, scale=False, name="norm_means_dev_map"
)

means_dev_mag_inf_alpha_input_layer = LearnableTensorLayer(
Expand All @@ -520,7 +520,7 @@ def _model_structure(config):
means_dev_mag_inf_beta_input_layer = LearnableTensorLayer(
shape=(config.n_subjects, config.n_modes, 1),
learn=config.learn_means,
initializer=initializers.TruncatedNormal(mean=0, stddev=0.02),
initializer=initializers.TruncatedNormal(mean=10, stddev=0.02),
name="means_dev_mag_inf_beta_input",
)
means_dev_mag_inf_beta_layer = layers.Activation(
Expand Down Expand Up @@ -591,7 +591,7 @@ def _model_structure(config):
config.n_channels * (config.n_channels + 1) // 2, name="covs_dev_map"
)
norm_covs_dev_map_layer = layers.LayerNormalization(
axis=-1, name="norm_covs_dev_map"
axis=-1, scale=False, name="norm_covs_dev_map"
)

covs_dev_mag_inf_alpha_input_layer = LearnableTensorLayer(
Expand All @@ -606,7 +606,7 @@ def _model_structure(config):
covs_dev_mag_inf_beta_input_layer = LearnableTensorLayer(
shape=(config.n_subjects, config.n_modes, 1),
learn=config.learn_covariances,
initializer=initializers.TruncatedNormal(mean=0, stddev=0.02),
initializer=initializers.TruncatedNormal(mean=10, stddev=0.02),
name="covs_dev_mag_inf_beta_input",
)
covs_dev_mag_inf_beta_layer = layers.Activation(
Expand Down

0 comments on commit 4dd93bd

Please sign in to comment.