Add alpha sampling for simplified State-DyNeMo (#296)
* add: sample alpha state probabilities

* fix: typos

* fix: sampling process
scho97 authored Oct 22, 2024
1 parent 410fdf0 commit 2bafe20
Showing 2 changed files with 79 additions and 4 deletions.
osl_dynamics/models/simplified_state_dynemo.py (77 additions, 2 deletions)
@@ -7,7 +7,9 @@
 
 import numpy as np
 import tensorflow as tf
+import tensorflow_probability as tfp
 from tensorflow.keras import layers
+from tqdm.auto import trange
 
 import osl_dynamics.data.tf as dtf
 from osl_dynamics.inference.layers import (
@@ -145,8 +147,81 @@ class Model(SimplifiedDyNeMo):
 
     config_type = Config
 
-    def sample_alpha(self, n_samples):
-        raise NotImplementedError
+    def sample_alpha(self, n_samples, input_mean=0, input_std=1):
+        """Uses the model RNN to sample a state probability time course, :code:`alpha`.
+
+        Parameters
+        ----------
+        n_samples : int
+            Number of samples to take.
+        input_mean : float
+            Mean of the normal distribution to sample the initial input data from.
+            Defaults to 0.
+        input_std : float
+            Standard deviation of the normal distribution to sample the initial input
+            data from. Defaults to 1.
+
+        Returns
+        -------
+        alpha : np.ndarray
+            Sampled alpha.
+        """
+        # Get layers
+        mod_rnn_layer = self.get_layer("mod_rnn")
+        theta_layer = self.get_layer("theta")
+        alpha_layer = self.get_layer("alpha")
+
+        # Sequence of the input data
+        input_data = np.zeros(
+            [self.config.sequence_length, self.config.n_channels],
+            dtype=np.float32,
+        )
+
+        # Randomly sample the first time step
+        input_data[-1] = np.random.normal(
+            loc=input_mean,
+            scale=input_std,
+            size=self.config.n_channels,
+        )
+
+        # Get observation model
+        mu = self.get_means()  # shape: (n_states, n_channels)
+        D = self.get_covariances()  # shape: (n_states, n_channels, n_channels)
+        mvns = [
+            tfp.distributions.MultivariateNormalTriL(
+                loc=tf.gather(mu, n, axis=-2),
+                scale_tril=tf.linalg.cholesky(tf.gather(D, n, axis=-3)),
+                allow_nan_stats=False,
+            )
+            for n in range(self.config.n_states)
+        ]
+
+        # Sample the state probability time course
+        alpha = np.empty([n_samples, self.config.n_states], dtype=np.float32)
+        for i in trange(n_samples, desc="Sampling state probability time course"):
+            # If there are leading zeros we trim the state probabilities so that
+            # we don't pass the zeros
+            trimmed_input = input_data[~np.all(input_data == 0, axis=1)][
+                np.newaxis, :, :
+            ]
+
+            # Predict the probability distribution function for theta one time
+            # step in the future
+            mod_rnn = mod_rnn_layer(trimmed_input)
+            theta = theta_layer(mod_rnn)[0, -1]
+
+            # Calculate the state probability time course
+            alpha[i] = alpha_layer(theta[np.newaxis, np.newaxis, :])[0, 0]
+
+            # Shift the input data one time step to the left
+            input_data = np.roll(input_data, -1, axis=0)
+
+            # Generate the next input data by sampling from the corresponding
+            # state-specific observation model
+            state = np.argmax(alpha[i])
+            input_data[-1] = mvns[state].sample()  # shape: (n_channels,)
+
+        return alpha
 
     def _model_structure(self):
         """Build the model structure."""
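For readers unfamiliar with the tensorflow_probability pattern used above: the new method builds one multivariate normal per state from the fitted means and covariances, then draws the next input from whichever state is currently most probable. A standalone sketch of that construction, using illustrative shapes and identity covariances rather than the model's fitted parameters:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative stand-ins for the fitted observation model (not real
# parameters): one mean vector and one covariance matrix per state
n_states, n_channels = 3, 4
mu = np.zeros([n_states, n_channels], dtype=np.float32)
D = np.stack([np.eye(n_channels, dtype=np.float32)] * n_states)

# One frozen distribution per state, parameterised by a Cholesky factor
mvns = [
    tfp.distributions.MultivariateNormalTriL(
        loc=mu[n],
        scale_tril=tf.linalg.cholesky(D[n]),
    )
    for n in range(n_states)
]

x = mvns[0].sample()  # one observation from state 0, shape (n_channels,)

Parameterising with scale_tril (the Cholesky factor of the covariance) is the standard tfp idiom: it avoids re-factorising the covariance on every sample draw inside the loop.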
osl_dynamics/models/state_dynemo.py (2 additions, 2 deletions)
@@ -221,7 +221,7 @@ def sample_alpha(self, n_samples, states=None):
 
         # Sample the state probability time course
         alpha = np.empty([n_samples, self.config.n_states], dtype=np.float32)
-        for i in trange(n_samples, desc="Sampling state time course"):
+        for i in trange(n_samples, desc="Sampling state probability time course"):
             # If there are leading zeros we trim the state time course so that
             # we don't pass the zeros
             trimmed_states = states[~np.all(states == 0, axis=1)][np.newaxis, :, :]
@@ -237,7 +237,7 @@
             # Sample from the probability distribution function
             states[-1] = states_layer(mod_theta[np.newaxis, np.newaxis, :][0])
 
-            # Calculate the state time courses
+            # Calculate the state probability time courses
             alpha[i] = alpha_layer(mod_theta[np.newaxis, np.newaxis, :])[0, 0]
 
         return alpha
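Both sample_alpha implementations share the same autoregressive bookkeeping: keep a fixed-length window, trim the all-zero padding before it reaches the network, predict one step ahead, then roll the window left and write the newest sample into the last slot. A minimal NumPy-only sketch of that loop, with a stand-in predictor in place of the model RNN:

import numpy as np

sequence_length, n_channels, n_samples = 10, 3, 5
window = np.zeros([sequence_length, n_channels], dtype=np.float32)
window[-1] = np.random.normal(size=n_channels)  # random first time step

def dummy_predict(x):
    # Stand-in for mod_rnn -> theta -> alpha: echoes a damped last row
    return 0.9 * x[-1]

for _ in range(n_samples):
    # Trim all-zero rows so the padding never reaches the predictor
    trimmed = window[~np.all(window == 0, axis=1)]
    prediction = dummy_predict(trimmed)

    # Shift one step left and append the newest sample
    window = np.roll(window, -1, axis=0)
    window[-1] = prediction + np.random.normal(scale=0.1, size=n_channels)

The window therefore warms up from a single random time step to a full sequence_length of generated data, exactly as in the methods above.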

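Putting the change to use: a hedged usage sketch, assuming `model` is an already-trained osl_dynamics.models.simplified_state_dynemo.Model (Config setup and training omitted):

import numpy as np

# `model` is assumed to be a trained simplified State-DyNeMo model
alpha = model.sample_alpha(n_samples=1000, input_mean=0, input_std=1)
print(alpha.shape)  # (1000, n_states)

# A hard state time course follows from the most probable state per sample
stc = np.argmax(alpha, axis=1)

Before this commit the simplified model's sample_alpha raised NotImplementedError, so downstream code that samples state probability time courses can now use it directly.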