Skip to content

Commit

Permalink
Modularity: pull maximize_expectation out of the ExpectationMaximizer
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkobunse committed Sep 20, 2024
1 parent f159a88 commit 8825a23
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
2 changes: 1 addition & 1 deletion qunfold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.5-rc"
__version__ = "0.1.5-rc2"

from .methods.linear.losses import (
LeastSquaresLoss,
Expand Down
36 changes: 26 additions & 10 deletions qunfold/methods/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,29 @@ def fit(self, X, y, n_classes=None):
self.classifier.fit(X, y)
return self
def predict(self, X):
    """Run the expectation maximization routine on the classifier's posteriors for `X`.

    Args:
        X: The items to predict on; passed to `self.classifier.predict_proba`.

    Returns:
        The `Result` returned by `maximize_expectation`.
    """
    # Delegate to the module-level routine instead of repeating the EM loop here,
    # so that there is a single implementation to maintain.
    return maximize_expectation(
        jnp.array(self.classifier.predict_proba(X)),  # P(Y|X)
        jnp.array(self.p_trn),  # P_trn(Y), which also serves as the first estimate
        self.max_iter,
        self.tol,
    )

def maximize_expectation(pYX, p_trn, max_iter=100, tol=1e-8):
    """Run the expectation maximization routine of Saerens et al. (2002).

    This routine is the part of the `ExpectationMaximizer` that iteratively
    re-weights the posteriors with the current estimate and averages them.

    Args:
        pYX: A JAX matrix of the posterior probabilities of a classifier, `P(Y|X)`. This matrix has to have the shape `(n_items, n_classes)`, as returned by some `classifier.predict_proba(X)`.
        p_trn: A JAX array of prior probabilities of the classifier. This array has to have the shape `(n_classes,)`.
        max_iter (optional): The maximum number of iterations. Defaults to `100`.
        tol (optional): The convergence tolerance for the L2 norm between iterations. Defaults to `1e-8`.

    Returns:
        A `Result` built from the final estimate, the number of iterations that were carried out, and a message that states whether the tolerance was met or `max_iter` was reached.
    """
    p_trn = jnp.array(p_trn)  # copy p_trn; the copy also serves as the first estimate

    # P(Y|X) / P_trn(Y). A prior of zero would produce NaN or inf here, which would
    # then turn the estimate of every class into NaN during the row normalization.
    # Classes with a zero prior therefore receive a weight of zero instead; the
    # inner `where` keeps the division itself free of zero denominators.
    has_prior = p_trn != 0
    pYX_pY = jnp.where(has_prior, pYX / jnp.where(has_prior, p_trn, 1.0), 0.0)

    p_prev = p_trn
    for n_iter in range(max_iter):
        # re-weight the posteriors with the current estimate and re-normalize each row
        posterior = pYX_pY * p_prev
        posterior = posterior / posterior.sum(axis=1, keepdims=True)
        p_next = posterior.mean(axis=0)
        if jnp.linalg.norm(p_next - p_prev) < tol:
            return Result(p_next, n_iter + 1, "Optimization terminated successfully.")
        p_prev = p_next
    return Result(p_prev, max_iter, "Maximum number of iterations reached.")

0 comments on commit 8825a23

Please sign in to comment.