diff --git a/osl_dynamics/inference/initializers.py b/osl_dynamics/inference/initializers.py index 0e81b84a..d0aed723 100644 --- a/osl_dynamics/inference/initializers.py +++ b/osl_dynamics/inference/initializers.py @@ -193,7 +193,8 @@ def reinitialize_layer_weights(layer): var = getattr(init_container, key.replace("_initializer", "")) # Assign new random values to the variable - var.assign(new_initializer(var.shape, var.dtype)) + if var is not None: + var.assign(new_initializer(var.shape, var.dtype)) def reinitialize_model_weights(model, keep=None): diff --git a/osl_dynamics/models/sedynemo.py b/osl_dynamics/models/sedynemo.py index c7fe59a5..8ef03471 100644 --- a/osl_dynamics/models/sedynemo.py +++ b/osl_dynamics/models/sedynemo.py @@ -505,7 +505,7 @@ def _model_structure(config): means_dev_map_layer = layers.Dense(config.n_channels, name="means_dev_map") norm_means_dev_map_layer = layers.LayerNormalization( - axis=-1, name="norm_means_dev_map" + axis=-1, scale=False, name="norm_means_dev_map" ) means_dev_mag_inf_alpha_input_layer = LearnableTensorLayer( @@ -520,7 +520,7 @@ def _model_structure(config): means_dev_mag_inf_beta_input_layer = LearnableTensorLayer( shape=(config.n_subjects, config.n_modes, 1), learn=config.learn_means, - initializer=initializers.TruncatedNormal(mean=0, stddev=0.02), + initializer=initializers.TruncatedNormal(mean=10, stddev=0.02), name="means_dev_mag_inf_beta_input", ) means_dev_mag_inf_beta_layer = layers.Activation( @@ -591,7 +591,7 @@ def _model_structure(config): config.n_channels * (config.n_channels + 1) // 2, name="covs_dev_map" ) norm_covs_dev_map_layer = layers.LayerNormalization( - axis=-1, name="norm_covs_dev_map" + axis=-1, scale=False, name="norm_covs_dev_map" ) covs_dev_mag_inf_alpha_input_layer = LearnableTensorLayer( @@ -606,7 +606,7 @@ def _model_structure(config): covs_dev_mag_inf_beta_input_layer = LearnableTensorLayer( shape=(config.n_subjects, config.n_modes, 1), learn=config.learn_covariances, - initializer=initializers.TruncatedNormal(mean=0, stddev=0.02), + initializer=initializers.TruncatedNormal(mean=10, stddev=0.02), name="covs_dev_mag_inf_beta_input", ) covs_dev_mag_inf_beta_layer = layers.Activation(