Commit 735e388

Refactor: avoid flake8 warning of undefined variable in hmm_poi evidence method.
RukuangHuang committed Feb 1, 2024
1 parent ae5ad50 commit 735e388
Showing 1 changed file with 77 additions and 74 deletions.
151 changes: 77 additions & 74 deletions osl_dynamics/models/hmm_poi.py
@@ -689,34 +689,6 @@ def get_posterior_expected_prior(self, gamma, xi):

return first_term + remaining_terms

def _evidence_predict_step(self, log_smoothing_distribution):
"""Predict step for calculating the evidence.
.. math::
p(s_t=j | x_{1:t-1}) = \displaystyle\sum_i p(s_t = j | s_{t-1} = i)\
p(s_{t-1} = i | x_{1:t-1})
Parameters
----------
log_smoothing_distribution : np.ndarray
:math:`\log p(s_{t-1} | x_{1:t-1})`.
Shape is (batch_size, n_states).
Returns
-------
log_prediction_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
"""
log_trans_prob = np.expand_dims(np.log(self.trans_prob), 0)
log_smoothing_distribution = np.expand_dims(
log_smoothing_distribution,
axis=-1,
)
log_prediction_distribution = logsumexp(
log_trans_prob + log_smoothing_distribution, -2
)
return log_prediction_distribution

def get_log_likelihood(self, data):
"""Get the log-likelihood of data, :math:`\log p(x_t | s_t)`.
@@ -742,41 +714,6 @@ def get_log_likelihood(self, data):
)
return log_likelihood.numpy()
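# --- Illustrative sketch (not part of the commit) ---
# The body of get_log_likelihood is outside this hunk. For a Poisson
# observation model, log p(x_t | s_t = k) is a sum of independent Poisson
# log-pmfs over channels; a minimal NumPy/SciPy version, assuming
# hypothetical per-state rates of shape (n_states, n_channels):
import numpy as np
from scipy.stats import poisson

def poisson_state_log_likelihood(data, rates):
    """data: (batch_size, n_channels) counts; returns (batch_size, n_states)."""
    # Broadcast to (batch_size, n_states, n_channels), then sum over channels
    return poisson.logpmf(data[:, None, :], rates[None, :, :]).sum(axis=-1)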

def _evidence_update_step(self, data, log_prediction_distribution):
"""Update step for calculating the evidence.
.. math::
p(s_t = j | x_{1:t}) &= \displaystyle\\frac{p(x_t | s_t = j) \
p(s_t = j | x_{1:t-1})}{p(x_t | x_{1:t-1})}
p(x_t | x_{1:t-1}) &= \displaystyle\sum_i p(x_t | s_t = j) \
p(s_t = i | x_{1:t-1})
Parameters
----------
data : np.ndarray
Data for the update step. Shape is (batch_size, n_channels).
log_prediction_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
Returns
-------
log_smoothing_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t})`. Shape is (batch_size, n_states).
predictive_log_likelihood : np.ndarray
:math:`\log p(x_t | x_{1:t-1})`. Shape is (batch_size,).
"""
log_likelihood = self.get_log_likelihood(data)
log_smoothing_distribution = log_likelihood + log_prediction_distribution
predictive_log_likelihood = logsumexp(log_smoothing_distribution, -1)

# Normalise the log smoothing distribution
log_smoothing_distribution -= np.expand_dims(
predictive_log_likelihood,
axis=-1,
)
return log_smoothing_distribution, predictive_log_likelihood

def get_stationary_distribution(self):
"""Get the stationary distribution of the Markov chain.
@@ -1022,6 +959,78 @@ def evidence(self, dataset):
evidence : float
Model evidence.
"""

# Helper functions
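# These helpers are nested inside evidence() rather than defined as class
# methods so they can close over batch_size, which is only assigned per
# batch further down; this appears to be the refactor that silences the
# flake8 undefined-variable warning mentioned in the commit message.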
def _evidence_predict_step(log_smoothing_distribution=None):
"""Predict step for calculating the evidence.
.. math::
p(s_t=j | x_{1:t-1}) = \displaystyle\sum_i p(s_t = j | s_{t-1} = i)\
p(s_{t-1} = i | x_{1:t-1})
Parameters
----------
log_smoothing_distribution : np.ndarray
:math:`\log p(s_{t-1} | x_{1:t-1})`.
Shape is (batch_size, n_states).
Returns
-------
log_prediction_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
"""
if log_smoothing_distribution is None:
    # First time step: use the stationary distribution as the prior.
    # np.log keeps the recursion in log space; broadcasting the raw
    # probabilities here would mix probability and log domains.
    initial_distribution = self.get_stationary_distribution()
    log_prediction_distribution = np.broadcast_to(
        np.expand_dims(np.log(initial_distribution), axis=0),
        (batch_size, self.config.n_states),
    )
else:
    log_trans_prob = np.expand_dims(np.log(self.trans_prob), 0)
    log_smoothing_distribution = np.expand_dims(
        log_smoothing_distribution,
        axis=-1,
    )
    log_prediction_distribution = logsumexp(
        log_trans_prob + log_smoothing_distribution, -2
    )
return log_prediction_distribution
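# --- Illustrative sketch (not part of the commit) ---
# The predict step is a matrix-vector product computed in log space:
# log p(s_t = j | x_{1:t-1}) = logsumexp_i [log P_ij + log p(s_{t-1} = i | x_{1:t-1})].
# Toy check with hypothetical numbers for a 2-state chain:
import numpy as np
from scipy.special import logsumexp

trans_prob = np.array([[0.9, 0.1],
                       [0.2, 0.8]])    # rows: s_{t-1} = i, cols: s_t = j
smoothing = np.array([[0.3, 0.7]])     # batch of one distribution
log_pred = logsumexp(
    np.log(trans_prob)[None] + np.log(smoothing)[..., None], axis=-2
)
# Same result as the direct probability-space product
np.testing.assert_allclose(np.exp(log_pred), smoothing @ trans_prob)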

def _evidence_update_step(data, log_prediction_distribution):
"""Update step for calculating the evidence.
.. math::
p(s_t = j | x_{1:t}) &= \displaystyle\\frac{p(x_t | s_t = j) \
p(s_t = j | x_{1:t-1})}{p(x_t | x_{1:t-1})}
p(x_t | x_{1:t-1}) &= \displaystyle\sum_i p(x_t | s_t = j) \
p(s_t = i | x_{1:t-1})
Parameters
----------
data : np.ndarray
Data for the update step. Shape is (batch_size, n_channels).
log_prediction_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
Returns
-------
log_smoothing_distribution : np.ndarray
:math:`\log p(s_t | x_{1:t})`. Shape is (batch_size, n_states).
predictive_log_likelihood : np.ndarray
:math:`\log p(x_t | x_{1:t-1})`. Shape is (batch_size,).
"""
log_likelihood = self.get_log_likelihood(data)
log_smoothing_distribution = log_likelihood + log_prediction_distribution
predictive_log_likelihood = logsumexp(log_smoothing_distribution, -1)

# Normalise the log smoothing distribution
log_smoothing_distribution -= np.expand_dims(
predictive_log_likelihood,
axis=-1,
)
return log_smoothing_distribution, predictive_log_likelihood
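# --- Illustrative sketch (not part of the commit) ---
# The update step is Bayes' rule in log space: the logsumexp over states is
# both the normaliser and the predictive log-likelihood log p(x_t | x_{1:t-1}).
# Toy check with hypothetical numbers:
import numpy as np
from scipy.special import logsumexp

log_lik = np.log(np.array([[0.5, 0.1]]))   # p(x_t | s_t) per state
log_pred = np.log(np.array([[0.4, 0.6]]))  # p(s_t | x_{1:t-1})
joint = log_lik + log_pred
log_norm = logsumexp(joint, axis=-1)       # log p(x_t | x_{1:t-1})
log_smooth = joint - log_norm[..., None]   # log p(s_t | x_{1:t})
np.testing.assert_allclose(np.exp(log_smooth).sum(axis=-1), 1.0)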

_logger.info("Getting model evidence")
dataset = self.make_dataset(dataset, concatenate=True)
n_batches = dtf.get_n_batches(dataset)
@@ -1033,24 +1042,18 @@ def evidence(self, dataset):
 pb_i = utils.Progbar(self.config.sequence_length)
 batch_size = tf.shape(x)[0]
 batch_evidence = np.zeros((batch_size))
+log_smoothing_distribution = None
 for t in range(self.config.sequence_length):
     # Prediction step
-    if t == 0:
-        initial_distribution = self.get_stationary_distribution()
-        log_prediction_distribution = np.broadcast_to(
-            np.expand_dims(initial_distribution, axis=0),
-            (batch_size, self.config.n_states),
-        )
-    else:
-        log_prediction_distribution = self._evidence_predict_step(
-            log_smoothing_distribution
-        )
+    log_prediction_distribution = _evidence_predict_step(
+        log_smoothing_distribution
+    )

     # Update step
     (
         log_smoothing_distribution,
         predictive_log_likelihood,
-    ) = self._evidence_update_step(x[:, t, :], log_prediction_distribution)
+    ) = _evidence_update_step(x[:, t, :], log_prediction_distribution)

     # Update the batch evidence
     batch_evidence += predictive_log_likelihood
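# --- Editorial note (not part of the commit) ---
# Accumulating predictive log-likelihoods implements the chain rule
#     log p(x_{1:T}) = sum_t log p(x_t | x_{1:t-1}),
# so batch_evidence ends up holding one log-evidence value per sequence
# in the batch.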
