diff --git a/osl_dynamics/models/hmm_poi.py b/osl_dynamics/models/hmm_poi.py
index e798a274..52fdc157 100644
--- a/osl_dynamics/models/hmm_poi.py
+++ b/osl_dynamics/models/hmm_poi.py
@@ -689,34 +689,6 @@ def get_posterior_expected_prior(self, gamma, xi):
         return first_term + remaining_terms
 
-    def _evidence_predict_step(self, log_smoothing_distribution):
-        """Predict step for calculating the evidence.
-
-        .. math::
-            p(s_t=j | x_{1:t-1}) = \displaystyle\sum_i p(s_t = j | s_{t-1} = i)\
-            p(s_{t-1} = i | x_{1:t-1})
-
-        Parameters
-        ----------
-        log_smoothing_distribution : np.ndarray
-            :math:`\log p(s_{t-1} | x_{1:t-1})`.
-            Shape is (batch_size, n_states).
-
-        Returns
-        -------
-        log_prediction_distribution : np.ndarray
-            :math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
-        """
-        log_trans_prob = np.expand_dims(np.log(self.trans_prob), 0)
-        log_smoothing_distribution = np.expand_dims(
-            log_smoothing_distribution,
-            axis=-1,
-        )
-        log_prediction_distribution = logsumexp(
-            log_trans_prob + log_smoothing_distribution, -2
-        )
-        return log_prediction_distribution
-
     def get_log_likelihood(self, data):
         """Get the log-likelihood of data, :math:`\log p(x_t | s_t)`.
 
@@ -742,41 +714,6 @@
         )
         return log_likelihood.numpy()
 
-    def _evidence_update_step(self, data, log_prediction_distribution):
-        """Update step for calculating the evidence.
-
-        .. math::
-            p(s_t = j | x_{1:t}) &= \displaystyle\\frac{p(x_t | s_t = j) \
-            p(s_t = j | x_{1:t-1})}{p(x_t | x_{1:t-1})}
-
-            p(x_t | x_{1:t-1}) &= \displaystyle\sum_i p(x_t | s_t = j) \
-            p(s_t = i | x_{1:t-1})
-
-        Parameters
-        ----------
-        data : np.ndarray
-            Data for the update step. Shape is (batch_size, n_channels).
-        log_prediction_distribution : np.ndarray
-            :math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
-
-        Returns
-        -------
-        log_smoothing_distribution : np.ndarray
-            :math:`\log p(s_t | x_{1:t})`. Shape is (batch_size, n_states).
-        predictive_log_likelihood : np.ndarray
-            :math:`\log p(x_t | x_{1:t-1})`. Shape is (batch_size,).
-        """
-        log_likelihood = self.get_log_likelihood(data)
-        log_smoothing_distribution = log_likelihood + log_prediction_distribution
-        predictive_log_likelihood = logsumexp(log_smoothing_distribution, -1)
-
-        # Normalise the log smoothing distribution
-        log_smoothing_distribution -= np.expand_dims(
-            predictive_log_likelihood,
-            axis=-1,
-        )
-        return log_smoothing_distribution, predictive_log_likelihood
-
     def get_stationary_distribution(self):
         """Get the stationary distribution of the Markov chain.
 
@@ -1022,6 +959,78 @@ def evidence(self, dataset):
         evidence : float
             Model evidence.
         """
+
+        # Helper functions
+        def _evidence_predict_step(log_smoothing_distribution=None):
+            """Predict step for calculating the evidence.
+
+            .. math::
+                p(s_t=j | x_{1:t-1}) = \displaystyle\sum_i p(s_t = j | s_{t-1} = i)\
+                p(s_{t-1} = i | x_{1:t-1})
+
+            Parameters
+            ----------
+            log_smoothing_distribution : np.ndarray or None
+                :math:`\log p(s_{t-1} | x_{1:t-1})`.
+                Shape is (batch_size, n_states).
+                If None (first time point), the log of the stationary
+                distribution is returned as the prediction.
+
+            Returns
+            -------
+            log_prediction_distribution : np.ndarray
+                :math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states).
+ """ + if log_smoothing_distribution is None: + initial_distribution = self.get_stationary_distribution() + log_prediction_distribution = np.broadcast_to( + np.expand_dims(initial_distribution, axis=0), + (batch_size, self.config.n_states), + ) + else: + log_trans_prob = np.expand_dims(np.log(self.trans_prob), 0) + log_smoothing_distribution = np.expand_dims( + log_smoothing_distribution, + axis=-1, + ) + log_prediction_distribution = logsumexp( + log_trans_prob + log_smoothing_distribution, -2 + ) + return log_prediction_distribution + + def _evidence_update_step(data, log_prediction_distribution): + """Update step for calculating the evidence. + + .. math:: + p(s_t = j | x_{1:t}) &= \displaystyle\\frac{p(x_t | s_t = j) \ + p(s_t = j | x_{1:t-1})}{p(x_t | x_{1:t-1})} + + p(x_t | x_{1:t-1}) &= \displaystyle\sum_i p(x_t | s_t = j) \ + p(s_t = i | x_{1:t-1}) + + Parameters + ---------- + data : np.ndarray + Data for the update step. Shape is (batch_size, n_channels). + log_prediction_distribution : np.ndarray + :math:`\log p(s_t | x_{1:t-1})`. Shape is (batch_size, n_states). + + Returns + ------- + log_smoothing_distribution : np.ndarray + :math:`\log p(s_t | x_{1:t})`. Shape is (batch_size, n_states). + predictive_log_likelihood : np.ndarray + :math:`\log p(x_t | x_{1:t-1})`. Shape is (batch_size,). + """ + log_likelihood = self.get_log_likelihood(data) + log_smoothing_distribution = log_likelihood + log_prediction_distribution + predictive_log_likelihood = logsumexp(log_smoothing_distribution, -1) + + # Normalise the log smoothing distribution + log_smoothing_distribution -= np.expand_dims( + predictive_log_likelihood, + axis=-1, + ) + return log_smoothing_distribution, predictive_log_likelihood + _logger.info("Getting model evidence") dataset = self.make_dataset(dataset, concatenate=True) n_batches = dtf.get_n_batches(dataset) @@ -1033,24 +1042,18 @@ def evidence(self, dataset): pb_i = utils.Progbar(self.config.sequence_length) batch_size = tf.shape(x)[0] batch_evidence = np.zeros((batch_size)) + log_smoothing_distribution = None for t in range(self.config.sequence_length): # Prediction step - if t == 0: - initial_distribution = self.get_stationary_distribution() - log_prediction_distribution = np.broadcast_to( - np.expand_dims(initial_distribution, axis=0), - (batch_size, self.config.n_states), - ) - else: - log_prediction_distribution = self._evidence_predict_step( - log_smoothing_distribution - ) + log_prediction_distribution = _evidence_predict_step( + log_smoothing_distribution + ) # Update step ( log_smoothing_distribution, predictive_log_likelihood, - ) = self._evidence_update_step(x[:, t, :], log_prediction_distribution) + ) = _evidence_update_step(x[:, t, :], log_prediction_distribution) # Update the batch evidence batch_evidence += predictive_log_likelihood