Skip to content

Commit

Permalink
Feat: display rho during HMM training.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgohil8 committed Dec 20, 2022
1 parent 2a61cdc commit 4154e20
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions osl_dynamics/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def build_model(self):
self.set_trans_prob(self.config.initial_trans_prob)
self.set_state_probs_t0(self.config.state_probs_t0)

def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
def fit(self, dataset, epochs=None, lr_decay=0.1, take=1, **kwargs):
"""Fit model to a dataset.
Iterates between:
Expand All @@ -150,7 +150,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
Returns
-------
history : dict
Dictionary with loss and rho history. Keys are 'loss' and 'rho'.
Dictionary with history of the loss and learning rates (lr and rho).
"""
if epochs is None:
epochs = self.config.n_epochs
Expand All @@ -170,6 +170,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
# If it's the last epoch, we train on the full dataset
take = 1

# Get the training data for this epoch
if take != 1:
dataset.shuffle(100000)
n_batches = max(round(n_total_batches * take), 1)
Expand Down Expand Up @@ -213,7 +214,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
print("\nTraining failed!")
return
loss.append(l)
pb_i.add(1, values=[("lr", lr), ("loss", l)])
pb_i.add(1, values=[("rho", self.rho), ("lr", lr), ("loss", l)])

history["loss"].append(np.mean(loss))
history["rho"].append(self.rho)
Expand Down Expand Up @@ -396,8 +397,6 @@ def _get_state_probs(self, x):
def _baum_welch(self, B, Pi_0, P):
"""Hidden state inference using the Baum-Welch algorithm.
This is a python implementation of the C++ library: https://github.com/OHBA-analysis/HMM-MAR/tree/master/utils/hidden_state_inference.
Parameters
----------
B : np.ndarray
Expand Down

0 comments on commit 4154e20

Please sign in to comment.