Monitor fractional occupancy during HMM training (#295)
cgohil8 authored Oct 20, 2024
1 parent b03c477 commit 410fdf0
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions osl_dynamics/models/hmm.py
@@ -33,7 +33,7 @@
from pqdm.threads import pqdm

import osl_dynamics.data.tf as dtf
from osl_dynamics.inference import initializers
from osl_dynamics.inference import initializers, modes
from osl_dynamics.inference.layers import (
CategoricalLogLikelihoodLossLayer,
CovarianceMatricesLayer,
@@ -200,6 +200,7 @@ def fit(
use_tqdm=False,
checkpoint_freq=None,
save_filepath=None,
dfo_tol=None,
verbose=1,
**kwargs,
):
@@ -223,6 +224,10 @@
Frequency (in epochs) of saving model checkpoints.
save_filepath : str, optional
Path to save the model.
dfo_tol : float, optional
When the maximum fractional occupancy change (from epoch to epoch)
is less than this value, we stop the training. If :code:`None`
there is no early stopping.
verbose : int, optional
Verbosity level. :code:`0=silent`.
kwargs : keyword arguments, optional
@@ -232,8 +237,8 @@
Returns
-------
history : dict
Dictionary with history of the loss and learning rates (:code:`lr`
and :code:`rho`).
Dictionary with history of the loss, learning rates (:code:`lr`
and :code:`rho`) and fractional occupancies during training.
"""
if epochs is None:
epochs = self.config.n_epochs
@@ -248,14 +253,17 @@
checkpoint_dir = f"{save_filepath}/checkpoints"
checkpoint_prefix = f"{checkpoint_dir}/ckpt"

if dfo_tol is None:
dfo_tol = 0

# Make a TensorFlow Dataset
dataset = self.make_dataset(dataset, shuffle=True, concatenate=True)

# Set static loss scaling factor
self.set_static_loss_scaling_factor(dataset)

# Training curves
history = {"loss": [], "rho": [], "lr": []}
history = {"loss": [], "rho": [], "lr": [], "fo": [], "max_dfo": []}

# Loop through epochs
if use_tqdm:
@@ -279,6 +287,7 @@

# Loop over batches
loss = []
occupancies = []
for data in dataset:
x = data["data"]

@@ -289,6 +298,10 @@
if self.config.learn_trans_prob:
self.update_trans_prob(gamma, xi)

# Calculate fractional occupancy
stc = modes.argmax_time_courses(gamma)
occupancies.append(np.sum(stc, axis=0))

# Reshape gamma: (batch_size*sequence_length, n_states)
# -> (batch_size, sequence_length, n_states)
gamma = gamma.reshape(x.shape[0], x.shape[1], -1)
@@ -318,10 +331,27 @@
history["rho"].append(self.rho)
history["lr"].append(lr)

occupancy = np.sum(occupancies, axis=0)
fo = occupancy / np.sum(occupancy)
history["fo"].append(fo)

# Save model checkpoint
if checkpoint_freq is not None and (n + 1) % checkpoint_freq == 0:
checkpoint.save(file_prefix=checkpoint_prefix)

# How much has the fractional occupancy changed?
if len(history["fo"]) == 1:
max_dfo = np.max(
np.abs(history["fo"][-1] - np.zeros_like(history["fo"][-1]))
)
else:
max_dfo = np.max(np.abs(history["fo"][-1] - history["fo"][-2]))
history["max_dfo"].append(max_dfo)
if dfo_tol > 0:
print(f"Max change in FO: {max_dfo}")
if max_dfo < dfo_tol:
break

if checkpoint_freq is not None:
np.save(f"{save_filepath}/trans_prob.npy", self.trans_prob)

@@ -1431,7 +1461,12 @@ def bayesian_information_criterion(self, dataset, loss_type="free_energy"):
return bic

def fine_tuning(
self, training_data, n_epochs=None, learning_rate=None, store_dir="tmp"
self,
training_data,
n_epochs=None,
learning_rate=None,
dfo_tol=None,
store_dir="tmp",
):
"""Fine tuning the model for each session.
@@ -1449,6 +1484,10 @@
learning_rate : float, optional
Learning rate. Defaults to the value in the :code:`config` used
to create the model.
dfo_tol : float, optional
When the maximum fractional occupancy change (from epoch to epoch)
is less than this value, we stop the training. If :code:`None`
there is no early stopping.
store_dir : str, optional
Directory to temporarily store the model in.
@@ -1486,7 +1525,7 @@ def fine_tuning(
for i in trange(training_data.n_sessions, desc="Fine tuning"):
# Train on this session
with training_data.set_keep(i):
self.fit(training_data, verbose=0)
self.fit(training_data, dfo_tol=dfo_tol, verbose=0)
a = self.get_alpha(training_data, concatenate=True)

# Get the inferred parameters

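A minimal usage sketch of the new dfo_tol argument (not part of the commit itself), assuming the usual osl_dynamics HMM workflow with Config, Data and Model; the configuration values and the data path below are illustrative placeholders.

```python
# Hypothetical usage of the new dfo_tol early-stopping option.
# Config values and the data path are placeholders.
from osl_dynamics.data import Data
from osl_dynamics.models.hmm import Config, Model

config = Config(
    n_states=8,
    n_channels=80,
    sequence_length=200,
    learn_means=False,
    learn_covariances=True,
    batch_size=16,
    learning_rate=0.01,
    n_epochs=40,
)

training_data = Data("path/to/training/data")  # placeholder path
model = Model(config)

# Stop training once the fractional occupancy changes by less than
# 1e-3 between consecutive epochs.
history = model.fit(training_data, dfo_tol=1e-3)

print(history["fo"][-1])   # fractional occupancy at the last epoch
print(history["max_dfo"])  # per-epoch maximum change in occupancy
```

The same flag is forwarded per-session by fine_tuning, which now calls self.fit(training_data, dfo_tol=dfo_tol, verbose=0) for each session.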

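To make the monitored quantity concrete, here is a self-contained NumPy sketch (with hypothetical helper names) of what the diff computes via modes.argmax_time_courses: hard-assign each sample to its most probable state, sum the assignments over time, normalise to get fractional occupancies, and take the maximum absolute change from the previous epoch as the early-stopping statistic.

```python
import numpy as np

def fractional_occupancy(gamma):
    """Fractional occupancy from state probabilities.

    gamma : (n_samples, n_states) array of state probabilities.
    """
    # Hard-assign each time point to its most probable state
    # (the role played by modes.argmax_time_courses in the diff).
    stc = np.zeros_like(gamma)
    stc[np.arange(gamma.shape[0]), np.argmax(gamma, axis=1)] = 1

    occupancy = np.sum(stc, axis=0)       # samples spent in each state
    return occupancy / np.sum(occupancy)  # normalise to fractions

# Toy example: state probabilities from two consecutive epochs
rng = np.random.default_rng(0)
fo_prev = fractional_occupancy(rng.dirichlet(np.ones(3), size=1000))
fo_curr = fractional_occupancy(rng.dirichlet(np.ones(3), size=1000))

# Early-stopping statistic compared against dfo_tol
max_dfo = np.max(np.abs(fo_curr - fo_prev))
print(fo_curr, max_dfo)
```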