diff --git a/sbayes/model/likelihood.py b/sbayes/model/likelihood.py index 045244b5..96221f9b 100644 --- a/sbayes/model/likelihood.py +++ b/sbayes/model/likelihood.py @@ -79,8 +79,13 @@ def get_observation_lhs( is_source = np.where(source.value.ravel()) return all_lh.ravel()[is_source] - def update_component_likelihoods(self, sample: Sample, caching=True) -> NDArray[float]: + def update_component_likelihoods( + self, + sample: Sample, + caching=True + ) -> NDArray[float]: # shape: (n_objects, n_features, n_components) """Update the likelihood values for each of the mixture components""" + CHECK_CACHING = False cache = sample.cache.component_likelihoods @@ -144,7 +149,7 @@ def compute_component_likelihood( groups: NDArray[bool], # (n_groups, n_sites) changed_groups: set[int], out: NDArray[float] -) -> NDArray[float]: # shape: (n_sites, n_features) +) -> NDArray[float]: # shape: (n_objects, n_features) out[~groups.any(axis=0), :] = 0. for i in changed_groups: g = groups[i]