From 4154e20a8bdf50a47a8394f54772e24a518dfd23 Mon Sep 17 00:00:00 2001
From: Chetan Gohil
Date: Mon, 19 Dec 2022 20:48:21 +0000
Subject: [PATCH] Feat: display rho during HMM training.

---
 osl_dynamics/models/hmm.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/osl_dynamics/models/hmm.py b/osl_dynamics/models/hmm.py
index 0a7b4831..88bb5fd8 100644
--- a/osl_dynamics/models/hmm.py
+++ b/osl_dynamics/models/hmm.py
@@ -125,7 +125,7 @@ def build_model(self):
         self.set_trans_prob(self.config.initial_trans_prob)
         self.set_state_probs_t0(self.config.state_probs_t0)
 
-    def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
+    def fit(self, dataset, epochs=None, lr_decay=0.1, take=1, **kwargs):
         """Fit model to a dataset.
 
         Iterates between:
@@ -150,7 +150,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
         Returns
         -------
         history : dict
-            Dictionary with loss and rho history. Keys are 'loss' and 'rho'.
+            Dictionary with history of the loss and learning rates (lr and rho).
         """
         if epochs is None:
             epochs = self.config.n_epochs
@@ -170,6 +170,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
                 # If it's the last epoch, we train on the full dataset
                 take = 1
 
+            # Get the training data for this epoch
             if take != 1:
                 dataset.shuffle(100000)
                 n_batches = max(round(n_total_batches * take), 1)
@@ -213,7 +214,7 @@ def fit(self, dataset, epochs=None, lr_decay=0.075, take=1, **kwargs):
                     print("\nTraining failed!")
                     return
                 loss.append(l)
-                pb_i.add(1, values=[("lr", lr), ("loss", l)])
+                pb_i.add(1, values=[("rho", self.rho), ("lr", lr), ("loss", l)])
 
             history["loss"].append(np.mean(loss))
             history["rho"].append(self.rho)
@@ -396,8 +397,6 @@ def _get_state_probs(self, x):
     def _baum_welch(self, B, Pi_0, P):
         """Hidden state inference using the Baum-Welch algorithm.
 
-        This is a python implementation of the C++ library: https://github.com/OHBA-analysis/HMM-MAR/tree/master/utils/hidden_state_inference.
-
         Parameters
         ----------
         B : np.ndarray
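
For context, a minimal usage sketch of the `fit` API this patch touches. Only `fit`'s signature (with `lr_decay` now defaulting to 0.1), the progress-bar change, and the `loss`/`rho` history keys are confirmed by the diff; the `Config` keyword arguments, their values, and the data path below are illustrative assumptions about the surrounding osl-dynamics API.

```python
# Hypothetical usage sketch -- Config values and the data path are
# placeholders, not taken from the patch.
from osl_dynamics.data import Data
from osl_dynamics.models.hmm import Config, Model

config = Config(
    n_states=5,
    n_channels=11,
    sequence_length=200,
    learn_means=False,
    learn_covariances=True,
    batch_size=16,
    learning_rate=0.01,
    n_epochs=10,
)
model = Model(config)

training_data = Data("path/to/data.npy")  # placeholder path

# lr_decay controls how quickly rho (the step size for the stochastic
# HMM parameter updates) decays over epochs; it now defaults to 0.1.
# During training the progress bar reports rho alongside lr and loss.
history = model.fit(training_data)

print(history["loss"])  # mean loss per epoch
print(history["rho"])   # rho at each epoch
```

Surfacing rho in the progress bar lets you check during training that the chosen `lr_decay` gives a sensible decay schedule, rather than inspecting `history["rho"]` only after the run finishes.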