Skip to content

Commit

Permalink
Refact: changed default learning rate for the HMM and save free energ…
Browse files Browse the repository at this point in the history
…y when training DyNeMo.
  • Loading branch information
cgohil8 committed Apr 17, 2023
1 parent 458497b commit 887e115
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions osl_dynamics/config_api/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def train_hmm(
{'sequence_length': 2000,
'batch_size': 32,
'learning_rate': 0.001,
'learning_rate': 0.01,
'n_epochs': 20}.
init_kwargs : dict
Keyword arguments to pass to :code:`Model.random_state_time_course_initialization`.
Expand Down Expand Up @@ -122,7 +122,7 @@ def train_hmm(
"n_channels": data.n_channels,
"sequence_length": 2000,
"batch_size": 32,
"learning_rate": 0.001,
"learning_rate": 0.01,
"n_epochs": 20,
}
config_kwargs = override_dict_defaults(default_config_kwargs, config_kwargs)
Expand Down Expand Up @@ -262,6 +262,9 @@ def train_dynemo(
# Training
history = model.fit(data, **fit_kwargs)

# Add free energy to the history object
history["free_energy"] = history["loss"][-1]

# Save trained model
_logger.info(f"Saving model to: {model_dir}")
model.save(model_dir)
Expand Down

0 comments on commit 887e115

Please sign in to comment.