Skip to content

Commit

Permalink
type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoNeureiter committed Nov 17, 2022
1 parent 4a2f2a3 commit bebd352
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions sbayes/model/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ def get_observation_lhs(
is_source = np.where(source.value.ravel())
return all_lh.ravel()[is_source]

def update_component_likelihoods(self, sample: Sample, caching=True) -> NDArray[float]:
def update_component_likelihoods(
self,
sample: Sample,
caching=True
) -> NDArray[float]: # shape: (n_objects, n_features, n_components)
"""Update the likelihood values for each of the mixture components"""

CHECK_CACHING = False

cache = sample.cache.component_likelihoods
Expand Down Expand Up @@ -144,7 +149,7 @@ def compute_component_likelihood(
groups: NDArray[bool], # (n_groups, n_sites)
changed_groups: set[int],
out: NDArray[float]
) -> NDArray[float]: # shape: (n_sites, n_features)
) -> NDArray[float]: # shape: (n_objects, n_features)
out[~groups.any(axis=0), :] = 0.
for i in changed_groups:
g = groups[i]
Expand Down

0 comments on commit bebd352

Please sign in to comment.